# Imports

In [None]:
pip install torchmetrics

In [None]:
from google.colab import drive
# Attach Google Drive so files under /content/drive become readable from the notebook.
drive.mount("/content/drive")

In [None]:
import sys
import os
import gc
import warnings
import random
from copy import deepcopy
import random
import math

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torchvision

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter
from torchmetrics import Accuracy
from torch.utils.data import Dataset, DataLoader
from skimage.transform import resize
from torch.optim import AdamW

from sklearn.model_selection import StratifiedKFold
    
from tqdm.notebook import tqdm

# Suppress every warning category to keep notebook output uncluttered.
warnings.simplefilter("ignore")
# Enable tqdm's pandas integration.
tqdm.pandas()

In [None]:
# Root folder of the project on the mounted Google Drive.
PROJECT_DIR = '/content/drive/MyDrive/a'

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EDA

In [None]:
# Class folder names under data/train, and the colour each class's bar is drawn in.
# 'White' is drawn as 'Pink' (presumably so its bar stays visible on a white background).
classes = ['Red', 'Green', 'Violet', 'White', 'Yellow', 'Brown', 'Black', 'Blue', 'Cyan', 'Grey', 'Orange']
classes_sns = ['Red', 'Green', 'Violet', 'Pink', 'Yellow', 'Brown', 'Black', 'Blue', 'Cyan', 'Grey', 'Orange']
# Number of directory entries (images) per class folder, in the order of `classes`.
counts = [
    len(os.listdir(f"{PROJECT_DIR}/data/train/{name}"))
    for name in classes
]

In [None]:
# Bar chart of images per class. x/y are passed as keywords because seaborn >= 0.12
# no longer accepts them positionally.
sns.barplot(x=classes, y=counts, palette=classes_sns)

# Dataset

In [None]:
classes = ['Red', 'Green', 'Violet', 'White', 'Yellow', 'Brown', 'Black', 'Blue', 'Cyan', 'Grey', 'Orange']
# Class name -> integer label, following the order of `classes` (Red -> 0, ..., Orange -> 10).
# NOTE(review): torchvision's ImageFolder assigns labels from the alphabetically sorted
# folder names, so these labels likely differ from ImageFolder's class_to_idx.
target_encoder = {name: index for index, name in enumerate(classes)}

In [None]:
import os
import pandas as pd
from torchvision.io import read_image


class CarDataset(Dataset):
    """Image dataset laid out as ``<dir>/<class_name>/<image file>``.

    Every sub-directory of ``dir`` is treated as a class; every entry inside it
    becomes one sample whose label is the sub-directory name.

    Args:
        dir: root folder whose sub-directories are the classes.
        transform: optional callable applied to each image returned by ``read_image``.
        target_encoder: optional mapping from class-name string to label. When it is
            falsy, the raw class-name string is returned as the label.
    """

    def __init__(self, dir, transform=None, target_encoder: dict = None):
        self.transform = transform
        self.target_encoder = target_encoder
        # Keep only directories (stray files such as .DS_Store would make the inner
        # os.listdir fail) and sort them, since os.listdir order is arbitrary.
        class_names = sorted(
            name for name in os.listdir(dir)
            if os.path.isdir(os.path.join(dir, name))
        )
        print(class_names)
        self.filenames = []
        self.labels = []

        for class_name in tqdm(class_names):
            class_path = f"{dir}/{class_name}"
            # Sorted so that the sample index -> file mapping is reproducible.
            for image in sorted(os.listdir(class_path)):
                self.filenames.append(f"{class_path}/{image}")
                self.labels.append(class_name)

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image comes from ``read_image`` and is passed through ``transform`` if one
        was given. The label is encoded via ``target_encoder`` if one was given.
        """
        img_path = self.filenames[idx]
        image = read_image(img_path)
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        if self.target_encoder:
            label = self.target_encoder[label]

        return image, label

In [None]:
# Index every training image, with labels encoded as integers via target_encoder.
dataset = CarDataset(
    f"{PROJECT_DIR}/data/train",
    target_encoder=target_encoder,
)

In [None]:
# Display how many samples CarDataset found.
len(dataset)

In [None]:
# Per-channel mean/std passed to Normalize (the commonly used ImageNet statistics).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(256, 256)),
    # ToTensor already converts PIL images from [0, 255] to floats in [0, 1].
    # The former extra `/ 255` step squashed inputs to ~[0, 0.004] before Normalize.
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean, std),
])

In [None]:
# Fixed sample indices: the first 1000 for training, the next 300 for evaluation.
# NOTE(review): ImageFolder lists samples class by class, so contiguous index ranges
# probably cover only the first few classes. Consider a shuffled or stratified split.
train_ids = list(range(1000))
eval_ids = list(range(1000, 1300))

In [None]:
# Derive the data root from PROJECT_DIR rather than repeating the Drive path literal,
# so moving the project only requires editing one constant.
data_path = f"{PROJECT_DIR}/data"
train_folder = torchvision.datasets.ImageFolder(data_path + '/train', transform=transforms)

In [None]:
# Two index-based views of the same ImageFolder.
train_subsampler = torch.utils.data.Subset(train_folder, train_ids)
eval_subsampler = torch.utils.data.Subset(train_folder, eval_ids)
# Batches of 64. Only the training loader reshuffles each epoch.
train_loader = torch.utils.data.DataLoader(
    train_subsampler, batch_size=64, shuffle=True, num_workers=1
)
eval_loader = torch.utils.data.DataLoader(
    eval_subsampler, batch_size=64, shuffle=False, num_workers=1
)

# Losses

In [None]:
class FocalLoss(nn.Module):
    """Focal loss for multi-class classification.

    Each sample's cross-entropy ``ce`` is scaled by ``(1 - p) ** gamma``, where
    ``p = exp(-ce)`` is the predicted probability of the true class. The batch mean
    of the scaled values is returned. With ``gamma=0`` this is plain cross-entropy.

    Args:
        gamma: focusing exponent; larger values down-weight confidently correct samples.
        eps: stored for backward compatibility; not used by ``forward``.
    """

    def __init__(self, gamma=0, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps  # unused; kept so existing constructor calls still work
        # reduction="none" yields one CE value per sample, so the focal weight is
        # applied per example. With the default "mean" it was applied to the batch mean.
        self.ce = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, input, target):
        """Return the mean focal loss for logits ``input`` and class indices ``target``."""
        logp = self.ce(input, target)  # per-sample cross-entropy
        p = torch.exp(-logp)  # probability assigned to the true class
        loss = (1 - p) ** self.gamma * logp
        return loss.mean()

# Loops

In [None]:
# Shared 11-class accuracy metric used by the train/eval loops.
# Current torchmetrics releases require `task`. The metric is moved to `device` so its
# internal state sits on the same device as the predictions it is called with.
accuracy = Accuracy(task="multiclass", num_classes=11).to(device)

In [None]:
def train_epoch(model, data_loader, loss_function, optimizer, scheduler, device, metric=None):
    """Run one training epoch and return ``(mean_loss, mean_accuracy)``.

    Args:
        model: network to train; put in train mode and moved to ``device``.
        data_loader: iterable of ``(inputs, target)`` batches.
        loss_function: callable ``(preds, target)`` returning a scalar loss tensor.
        optimizer: stepped once per batch.
        scheduler: learning-rate scheduler, stepped once per epoch after all batches.
        device: device that the model and each batch are moved to.
        metric: optional callable ``(preds, target)`` returning the batch accuracy.
            Defaults to the module-level ``accuracy`` metric.

    Returns:
        Floats ``(loss, acc)``, each averaged over samples (batch values weighted by
        batch size). Both are NaN when the loader yields no samples.
    """
    if metric is None:
        metric = accuracy

    model.train(True)
    model.to(device)
    total_loss, total_acc, seen = 0.0, 0.0, 0

    for inputs, target in data_loader:
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        preds = model(inputs)

        loss = loss_function(preds, target)
        loss.backward()
        optimizer.step()

        # Loss and metric are batch means; weight them by batch size so a short
        # final batch does not skew the epoch average.
        batch_size = target.size(0)
        total_loss += loss.item() * batch_size
        total_acc += float(metric(preds.detach(), target)) * batch_size
        seen += batch_size

    # Epoch-level decay. Stepping per batch shrank the LR by gamma ** num_batches per epoch.
    scheduler.step()

    if seen == 0:
        return float("nan"), float("nan")
    return total_loss / seen, total_acc / seen
    
    
def eval_epoch(model, data_loader, loss_function, device, metric=None):
    """Evaluate ``model`` over ``data_loader`` without updating it.

    Args:
        model: network to evaluate; put in eval mode and moved to ``device``.
        data_loader: iterable of ``(inputs, target)`` batches.
        loss_function: callable ``(preds, target)`` returning a scalar loss tensor.
        device: device that the model and each batch are moved to.
        metric: optional callable ``(preds, target)`` returning the batch accuracy.
            Defaults to the module-level ``accuracy`` metric.

    Returns:
        Floats ``(loss, acc)``, each averaged over samples (batch values weighted by
        batch size). Both are NaN when the loader yields no samples.
    """
    if metric is None:
        metric = accuracy

    model.train(False)
    model.to(device)
    total_loss, total_acc, seen = 0.0, 0.0, 0

    # No gradients are needed anywhere in evaluation.
    with torch.no_grad():
        for inputs, target in data_loader:
            inputs, target = inputs.to(device), target.to(device)
            preds = model(inputs)
            # Weight batch means by batch size. Averaging per batch would over-weight
            # a short final batch.
            batch_size = target.size(0)
            total_loss += loss_function(preds, target).item() * batch_size
            total_acc += float(metric(preds, target)) * batch_size
            seen += batch_size

    if seen == 0:
        return float("nan"), float("nan")
    return total_loss / seen, total_acc / seen

# Model and training

In [None]:
# ResNet-18 initialised from the IMAGENET1K_V1 pretrained weights.
model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
)

# AdamW over all current parameters, with exponential LR decay (x0.8 per scheduler step).
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)
# Focal loss with its default gamma.
loss = FocalLoss()

In [None]:
# Freeze every parameter whose name does not start with "layer4.1".
for param_name, weight in model.named_parameters():
    if param_name.startswith("layer4.1"):
        continue
    weight.requires_grad = False

# Print the names of the parameters that remain trainable.
for param_name, weight in model.named_parameters():
    if weight.requires_grad:
        print(param_name)

In [None]:
# Swap the classification head for a fresh, trainable 11-way linear layer.
model.fc = torch.nn.Linear(model.fc.in_features, 11, bias=True)

# The optimizer was created from model.parameters() before this swap, so it still
# holds the old head's tensors and would never update the new one. Rebuild it over
# the currently trainable parameters, reusing the original lr and decay factor.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=optimizer.defaults["lr"],
)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=scheduler.gamma)

In [None]:
epochs = 5
# One training pass then one evaluation pass per epoch, reporting both.
for epoch in tqdm(range(1, epochs + 1)):
    train_loss, train_accuracy = train_epoch(model, train_loader, loss, optimizer, scheduler, device)
    test_loss, test_accuracy = eval_epoch(model, eval_loader, loss, device)
    print(f'\n Epoch #{epoch}\nTrain loss = {train_loss}, Train accuracy = {train_accuracy}')
    print(f'Test loss = {test_loss}, Test accuracy = {test_accuracy}')

# Cross Validation