# 1. Some Common Imports

In [None]:
import sys
# Make modules located under /dataset importable from this notebook.
sys.path.append('/dataset')

In [1]:
import torch
import torchvision
from torch import nn
import cv2

import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt 

from sklearn.model_selection import train_test_split
from tqdm import tqdm

# 2. Setup Configurations

In [None]:
# Metadata CSV and the directory that image paths are joined onto.
CSV_FILE = "./output/cropmeta.csv"
DATA_DIR = "./output/"

# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  DEVICE = torch.device('cuda')
else:
  DEVICE = torch.device('cpu')
print(DEVICE)

# Training hyperparameters.
EPOCHS = 25
LR = 0.003
BATCH_SIZE = 32
IMAGE_SIZE = 224

In [None]:
# Read the metadata CSV; header=None means the first line is treated as data
# and the columns get integer labels (0, 1, ...).
df = pd.read_csv(CSV_FILE, header=None)
# Preview the first rows (rendered as the cell output).
df.head()

In [None]:
# Load one sample image from the metadata table as a quick sanity check.
idx = 1

row = df.iloc[idx]

# Column 0 holds the image path relative to DATA_DIR.
image_path = DATA_DIR + row[0]

image = cv2.imread(image_path)
# cv2.imread signals an unreadable/missing file by returning None instead of
# raising, so fail here with a clear message rather than later in plotting.
if image is None:
  raise FileNotFoundError(f"Could not read image: {image_path}")

In [None]:
%matplotlib inline

plt.imshow(image)

In [None]:
# Hold out 20% of the rows for validation; the fixed random_state keeps the
# split reproducible between runs.
train_df, valid_df = train_test_split(df, test_size=0.20, random_state=32)

# 3. Augmentation Functions

Albumentations documentation: https://albumentations.ai/docs/

In [None]:
import albumentations as A

In [None]:
def get_train_augs():
  """Build the training pipeline: resize to IMAGE_SIZE, then random flips."""
  transforms = [
      A.Resize(IMAGE_SIZE, IMAGE_SIZE),
      A.HorizontalFlip(p=0.50),  # applied with probability 0.5
      A.VerticalFlip(p=0.50),    # applied with probability 0.5
  ]
  return A.Compose(transforms)

def get_valid_augs():
  """Build the validation pipeline: resize to IMAGE_SIZE only, no randomness."""
  transforms = [A.Resize(IMAGE_SIZE, IMAGE_SIZE)]
  return A.Compose(transforms)

# 4. Create Custom Dataset

In [None]:
from torch.utils.data import Dataset

In [None]:
class ClassificationDataset(Dataset):
  """Image-classification dataset backed by a metadata DataFrame.

  Each row provides the image path (``row[0]``, appended to ``DATA_DIR``) and
  the label (``row[1]``).

  Args:
    df: DataFrame whose column 0 is the image path and column 1 the label.
    augmentations: Callable invoked as ``augmentations(image=...)`` that
      returns a mapping with an ``"image"`` entry; pass a falsy value to
      skip augmentation.
  """

  def __init__(self, df, augmentations):
    self.df = df
    self.augmentations = augmentations

  def __len__(self):
    """Return the number of rows (samples) in the DataFrame."""
    return len(self.df)

  def __getitem__(self, idx):
    """Load, augment and scale the sample at position ``idx``.

    Returns:
      ``(image, label)`` where ``image`` is a float32 tensor laid out as
      (c, h, w) with values divided by 255, and ``label`` is the raw value
      taken from column 1.

    Raises:
      FileNotFoundError: If OpenCV cannot read the image file.
    """
    row = self.df.iloc[idx]

    image_path = DATA_DIR + row[0]
    label = row[1]

    image = cv2.imread(image_path)  # (h, w, c)
    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; surface that here rather than deep inside the augmentations.
    if image is None:
      raise FileNotFoundError(f"Could not read image: {image_path}")

    # NOTE(review): channels are left in OpenCV's BGR order. ImageNet-pretrained
    # torchvision weights were trained on RGB input; consider
    # cv2.cvtColor(image, cv2.COLOR_BGR2RGB) here, but only together with
    # retraining, since existing checkpoints were fitted on BGR tensors.

    if self.augmentations:
      image = self.augmentations(image=image)["image"]  # (h, w, c)

    image = np.transpose(image, (2, 0, 1)).astype(np.float32)  # (c, h, w)

    image = torch.Tensor(image) / 255.0

    return image, label

In [None]:
# Training samples go through the random-flip pipeline; validation samples
# are only resized.
trainset = ClassificationDataset(df=train_df, augmentations=get_train_augs())
validset = ClassificationDataset(df=valid_df, augmentations=get_valid_augs())

In [None]:
# Report how many samples ended up in each split.
for split_name, split in (("trainset", trainset), ("validset", validset)):
  print(f"Size of {split_name}: {len(split)}")

In [None]:
# Fetch a single training sample to sanity-check the dataset output.
idx = 21

# __getitem__ returns the (c, h, w) float tensor and the label from column 1.
image, label = trainset[idx]

print(image.shape, label)

# 5. Load dataset into batches

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Shuffle only the training data; validation keeps a fixed order.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
validloader = DataLoader(validset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Report the number of batches each loader will yield per epoch.
for loader_name, loader in (("trainloader", trainloader), ("validloader", validloader)):
  print(f"Total number of batches in {loader_name}: {len(loader)}")

In [None]:
# Peek at the first batch only, to sanity-check the tensor shapes.
for images, labels in trainloader:
  print(f"One batch image shape: {images.shape}")
  print(f"One batch label shape {labels.shape}")
  break

# 6. Create Classification Model

In [None]:
import torchvision
from torch import nn

# Load ResNet-18 with pretrained weights purely to inspect its classifier head.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed version.
pretrained_net = torchvision.models.resnet18(pretrained=True)

# Display the final fully connected layer (rendered as the cell output).
pretrained_net.fc

In [None]:
# Start from pretrained ResNet-18 weights and swap the classifier head for a
# freshly initialised 2-class linear layer.
finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(
    in_features=finetune_net.fc.in_features,
    out_features=2,
)
nn.init.xavier_uniform_(finetune_net.fc.weight)
finetune_net.to(DEVICE)

# 7. Create Train and Validation Function

In [None]:
def train_fn(dataloader, model, optimizer):
  """Run one training epoch over ``dataloader``.

  Args:
    dataloader: Iterable of ``(images, labels)`` batches that supports
      ``len()`` and exposes a ``dataset`` attribute supporting ``len()``.
    model: Network to train; it is switched to training mode.
    optimizer: Optimizer that is zeroed and stepped once per batch.

  Returns:
    ``(mean_loss, accuracy)`` as Python floats: the cross-entropy loss
    averaged over batches and the fraction of ``dataloader.dataset``
    classified correctly.
  """
  model.train()  # Turn ON dropout, batchnorm, etc.

  total_loss = 0.0
  correct = 0
  loss_func = nn.CrossEntropyLoss()

  for images, labels in tqdm(dataloader):
    images = images.to(DEVICE)
    labels = labels.to(DEVICE)

    optimizer.zero_grad()
    outputs = model(images)

    loss = loss_func(outputs, labels)
    loss.backward()
    optimizer.step()

    total_loss += loss.item()

    # Tensor.sum() reduces in a single op and .item() yields a Python int;
    # the builtin sum() iterated the tensor element by element and left the
    # running count (and the returned accuracy) as a device tensor.
    correct += (outputs.argmax(dim=1) == labels).sum().item()

  return total_loss / len(dataloader), correct / len(dataloader.dataset)

In [None]:
def eval_fn(dataloader, model):
  """Evaluate ``model`` on ``dataloader`` without updating any weights.

  Args:
    dataloader: Iterable of ``(images, labels)`` batches that supports
      ``len()`` and exposes a ``dataset`` attribute supporting ``len()``.
    model: Network to evaluate; it is switched to evaluation mode.

  Returns:
    ``(mean_loss, accuracy)`` as Python floats: the cross-entropy loss
    averaged over batches and the fraction of ``dataloader.dataset``
    classified correctly.
  """
  model.eval()  # Turn OFF dropout, batchnorm, etc.

  total_loss = 0.0
  correct = 0
  loss_func = nn.CrossEntropyLoss()

  # No gradients are needed for evaluation.
  with torch.no_grad():
    for images, labels in tqdm(dataloader):
      images = images.to(DEVICE)
      labels = labels.to(DEVICE)

      outputs = model(images)
      total_loss += loss_func(outputs, labels).item()

      # Tensor.sum() reduces in a single op and .item() yields a Python int;
      # the builtin sum() iterated the tensor element by element and left the
      # running count (and the returned accuracy) as a device tensor.
      correct += (outputs.argmax(dim=1) == labels).sum().item()

  return total_loss / len(dataloader), correct / len(dataloader.dataset)

# 8. Train Model

In [None]:
# Only the parameters of the replaced classifier head (fc) are handed to Adam,
# so the optimizer never updates the rest of the network.
optimizer = torch.optim.Adam(finetune_net.fc.parameters(), lr = LR)

In [None]:
# Lowest validation loss seen so far. float("inf") is used instead of np.Inf
# because that alias was removed in NumPy 2.0.
best_loss = float("inf")

for epoch in range(EPOCHS):
  train_loss, train_accuracy = train_fn(trainloader, finetune_net, optimizer)
  valid_loss, valid_accuracy = eval_fn(validloader, finetune_net)

  # Checkpoint whenever the validation loss improves.
  if valid_loss < best_loss:
    torch.save(finetune_net.state_dict(), "best_model.pt")
    print("SAVED MODEL")
    best_loss = valid_loss

  print(f"Epoch: {epoch+1} Train Loss: {train_loss} Train Accuracy: {train_accuracy} Valid Loss: {valid_loss} Valid Accuracy: {valid_accuracy}")

# 9. Inference

In [None]:
# Rebuild the fine-tuned architecture and restore the best checkpoint.
model = torchvision.models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("best_model.pt", map_location=DEVICE))

# The input below is moved to DEVICE, so the model must live there too;
# otherwise the forward pass fails with a device mismatch on CUDA machines.
model.to(DEVICE)
# Inference must use the stored BatchNorm statistics, not per-batch ones.
model.eval()

idx = 20

image, label = validset[idx]

# No gradients are needed for prediction.
with torch.no_grad():
  output = model(image.to(DEVICE).unsqueeze(0)) # (c, h, w) -> (b, c, h, w)
print(f"Output shape: {output.shape}")
pred = output.argmax(dim=1).cpu().item()
print(f"pred: {pred}, label: {label}, {pred==label}")

In [None]:
count = 0
num_of_valid_data = len(validloader.dataset)

# Batches are moved to DEVICE below, so make sure the model is there as well
# and uses its stored BatchNorm statistics (both calls are idempotent).
model.to(DEVICE)
model.eval()

# No gradients are needed for prediction.
with torch.no_grad():
  for images, labels in tqdm(validloader):

    images = images.to(DEVICE)
    labels = labels.to(DEVICE)
    outputs = model(images)
    pred = outputs.argmax(dim=1)
    # Tensor.sum().item() gives a plain Python int; the builtin sum() iterated
    # the tensor element by element and kept the count as a device tensor.
    count += (pred==labels).sum().item()
    print(pred==labels)

print(f"Prediction Accuracy: {count / num_of_valid_data}")
print(f"Wrong predictions: {num_of_valid_data - count}")

In [None]:
# The dataset yields a (c, h, w) tensor whose channels are in OpenCV's BGR
# order; matplotlib expects an (h, w, c) array in RGB order. Use explicit
# tensor ops instead of np.transpose on a torch tensor, then reverse the
# channel axis so the colours are displayed correctly.
image = image.permute(1, 2, 0).numpy()[..., ::-1]
plt.imshow(image)

In [None]:
import pickle

# Save the whole model object. The context manager guarantees the file is
# flushed and closed even if pickling fails (the bare open() leaked the handle).
# NOTE(review): the pickle records tensors on whatever device the model
# currently lives on, and unpickling executes arbitrary code — only load
# model.pkl from trusted sources.
with open('model.pkl', 'wb') as model_file:
  pickle.dump(model, model_file)

In [5]:
# Build an uninitialised ResNet-50 and give it a 2-output classifier head so
# its layout matches the checkpoint being loaded.
model50 = torchvision.models.resnet50()
model50.fc = nn.Linear(model50.fc.in_features, 2)

# map_location="cpu" loads the stored tensors onto the CPU.
model50.load_state_dict(torch.load("best_model_resnet50.pt", map_location="cpu"))


<All keys matched successfully>

In [6]:
import pickle

# Save the whole model object. The context manager guarantees the file is
# flushed and closed even if pickling fails (the bare open() leaked the handle).
# NOTE(review): unpickling executes arbitrary code — only load model50.pkl
# from trusted sources.
with open('model50.pkl', 'wb') as model_file:
  pickle.dump(model50, model_file)