<a href="https://www.kaggle.com/code/tunkhanhbi/4-animals?scriptVersionId=182400815" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Image classification on the 4 Animals Kaggle dataset

The dataset description and leaderboard are available [here](https://www.kaggle.com/competitions/4-animal-classification/).

## 1. Import modules

In [None]:
import csv
import torch
from torchvision import transforms, models
from torch import nn
from tqdm import tqdm
from glob import glob
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

In [None]:
# Decide once whether a CUDA GPU is usable; every later cell places tensors
# and the model on `device`.
CUDA = torch.cuda.is_available()
device = "cuda" if CUDA else "cpu"

if not CUDA:
    print("GPU is not available. Using CPU")
else:
    print("GPU is available")
    print("GPU device:", torch.cuda.get_device_name(0))
    # total_memory is reported in bytes; show it as whole GiB (1024**3 bytes).
    gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print("GPU memory:", round(gpu_bytes / 1024**3), "GB")


## 2. Process data

In [None]:

# Load every image of the competition, resize / center-crop / normalize it,
# and keep the resulting tensors on the CPU.
#   - test images: keep the numeric id taken from the file name so the
#     predictions can be matched back in the submission file;
#   - train images: the label is the index of the folder name in `animals`.
test_data = []
test_id = []
train_data_x = []
train_data_y = []

transform = transforms.Compose([
    transforms.PILToTensor(),                          # uint8 tensor, (C, H, W)
    transforms.Lambda(lambda x: x.to(device) / 255),   # scale to [0, 1] on `device`
    transforms.Resize((256, 256), antialias=True),
    transforms.CenterCrop((224, 224)),
    # Standard ImageNet mean/std used by torchvision's pretrained models.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _load_image(path):
    """Open the image at *path*, force 3 RGB channels and apply `transform`.

    The file handle is closed before returning, and the tensor is moved back
    to the CPU so the whole dataset never has to sit on the GPU at once.
    """
    with Image.open(path) as img:
        # Grayscale, palette or RGBA files would otherwise produce 1 or 4
        # channels, which breaks the 3-channel Normalize and the later
        # stacking of all images into one tensor.
        return transform(img.convert("RGB")).to("cpu")


# glob() returns files in arbitrary order; sorting makes the image order (and
# therefore the seeded train/val split later on) reproducible.
for file in sorted(glob("/kaggle/input/4-animal-classification/test/test/*")):
    test_data.append(_load_image(file))

    # File names look like "<id>.<ext>"; the integer id goes into the submission.
    image_id = file.split("/")[-1].split(".")[0]
    test_id.append(int(image_id))

print(f"Processed {len(test_data)} image.")

animals = ['cat', 'deer', 'dog', 'horse']

for label, animal in enumerate(animals):
    for file in sorted(glob(f"/kaggle/input/4-animal-classification/train/{animal}/*")):
        train_data_x.append(_load_image(file))
        train_data_y.append(label)

print(f"Processed {len(train_data_x)} image.")

In [None]:
# Stack the per-image tensors into single batched tensors. torch.stack copies
# the data once; the previous np.array round-trip first built an extra
# full-size NumPy copy of the whole dataset.
# Note: torch.stack raises if a list is empty (e.g. a wrong input path), which
# surfaces the problem here instead of in a later cell.
test_data = torch.stack(test_data).to(torch.float32)
test_id = torch.tensor(test_id, dtype=torch.long)
train_data_x = torch.stack(train_data_x).to(torch.float32)
train_data_y = torch.tensor(train_data_y, dtype=torch.long)


In [None]:
# Sanity-check the shapes produced by the preprocessing cells.
for caption, tensor in (
    ("Train data shape:", train_data_x),
    ("Train label shape:", train_data_y),
    ("Test data shape:", test_data),
    ("Test id shape: ", test_id),
):
    print(caption, tensor.shape)

In [None]:
# Persist the preprocessed tensors so later cells (or a restarted kernel) can
# reload them without decoding every image again.
for tensor, path in (
    (train_data_x, "/kaggle/working/train_data_x.pt"),
    (train_data_y, "/kaggle/working/train_data_y.pt"),
    (test_data, "/kaggle/working/test_data.pt"),
    (test_id, "/kaggle/working/test_id.pt"),
):
    torch.save(tensor, path)


## 3. Define the model architecture

In [None]:
# VGG16 with torchvision's default pretrained weights. The final classifier
# layer (index 6, 4096 inputs) is swapped for a 4-way head: one logit per
# animal class (cat, deer, dog, horse).
model = models.vgg16(weights='DEFAULT')
head = nn.Linear(in_features=4096, out_features=4)
model.classifier[6] = head
print(model.forward)

## 4. Train the model

In [None]:
# Reload the preprocessed training tensors saved earlier and move them onto
# `device`; the full training set stays there for the rest of training.
train_data_x, train_data_y = (
    torch.load(path).to(device)
    for path in (
        "/kaggle/working/train_data_x.pt",
        "/kaggle/working/train_data_y.pt",
    )
)


In [None]:
# Mini-batch size shared by all loaders, and a seeded 75/25 train/validation
# split of the training images.
batch_size = 32

split = train_test_split(
    train_data_x,
    train_data_y,
    test_size=0.25,
    random_state=42,
)
train_x, val_x, train_y, val_y = split

class Data(Dataset):
    """Minimal map-style dataset that pairs ``data[i]`` with ``label[i]``.

    Both arguments should be indexable along their first axis and have the
    same length (tensors in this notebook). Nothing is copied.
    """

    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __len__(self):
        # The sample count is the size of the first dimension of `data`.
        return self.data.shape[0]

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.label[idx]
        return sample, target


# Wrap the split tensors in datasets and loaders. The training loader
# reshuffles every epoch, so the model does not see the same batch sequence
# each time. The validation loader keeps a fixed order because it is only
# evaluated.
train_data = Data(train_x, train_y)
val_data = Data(val_x, val_y)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size)



In [None]:
# Cross-entropy loss over the model's class logits, optimised with Adam.
# The model is moved to `device` *before* the optimizer is constructed, as the
# torch.optim documentation recommends. This way the optimizer is built over
# the parameters that will actually be trained.
model = model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [None]:


# Train for 10 epochs. After every epoch the model is evaluated on the
# validation split, and the weights are checkpointed whenever the mean
# validation loss reaches a new minimum.
min_val_loss = float("inf")
for epoch in range(10):
    total_loss_train = 0
    total_acc_train = 0

    model.train()
    for x, y in tqdm(train_loader):
        output = model(x.float())

        batch_loss = criterion(output, y)
        # CrossEntropyLoss returns the *mean* over the batch. Weight it by the
        # batch size so that dividing by the sample count below gives a true
        # per-sample average (the last batch may be smaller).
        total_loss_train += batch_loss.item() * y.size(0)
        total_acc_train += (output.argmax(dim=1) == y).sum().item()

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    total_loss_val = 0
    total_acc_val = 0

    model.eval()
    with torch.no_grad():
        for x, y in tqdm(val_loader):
            output = model(x.float())
            batch_loss = criterion(output, y)
            total_loss_val += batch_loss.item() * y.size(0)
            total_acc_val += (output.argmax(dim=1) == y).sum().item()

    train_loss = total_loss_train / len(train_x)
    train_acc = total_acc_train / len(train_x)
    val_loss = total_loss_val / len(val_x)
    val_acc = total_acc_val / len(val_x)

    # Implicit string concatenation: a backslash continuation inside the
    # f-string would copy the source indentation into the printed line.
    print(
        f'Epochs: {epoch+1} | Train Loss: {train_loss:.3f}'
        f' | Train Accuracy: {train_acc:.3f}'
        f' | Val Loss: {val_loss:.3f}'
        f' | Val Accuracy: {val_acc:.3f}'
    )

    if val_loss < min_val_loss:
        min_val_loss = val_loss
        torch.save(model.state_dict(), "/kaggle/working/model.pt")
        print(f"Save model because val loss improve loss {min_val_loss:.3f}")

    print("-" * 50)

## 5. Predict on the test set

In [None]:
# Reload the test ids (kept on the CPU; they are only written to the CSV) and
# the preprocessed test images (moved to `device` for inference).
test_id = torch.load("/kaggle/working/test_id.pt")
test_data = torch.load("/kaggle/working/test_data.pt").to(device)

In [None]:
# Batch the test set. Note the argument order: the ids go in the `data` slot
# and the images in the `label` slot, so each batch unpacks as (ids, images).
test_loader = DataLoader(
    Data(data=test_id, label=test_data),
    batch_size=batch_size,
)

In [None]:
# Restore the checkpoint saved at the lowest validation loss during training.
# map_location places the tensors on the current `device`, so loading also
# works after a kernel restart on a different device than the one that saved
# the checkpoint.
model.load_state_dict(torch.load("/kaggle/working/model.pt", map_location=device))
model = model.to(device)

In [None]:
# Predict a class index for every test image and write the submission file
# with one "Id,Label" row per image.
# newline="" lets the csv module control line endings itself; without it,
# extra blank rows appear on Windows.
with open("/kaggle/working/submission.csv", "w", newline="") as f:
    csvwriter = csv.writer(f)
    csvwriter.writerow(["Id", "Label"])
    model.eval()
    with torch.no_grad():
        for id_list, img in test_loader:
            output = model(img)
            # One device->host transfer per batch instead of one .item() per
            # element (each .item() forces a synchronisation on the GPU).
            label_list = output.argmax(dim=1).tolist()
            csvwriter.writerows(zip(id_list.tolist(), label_list))