### Imports

In [None]:
import torch
import pandas as pd
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchvision import datasets, models
from torch import nn
from torchvision.transforms import v2
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, random_split
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

### Parameters

In [None]:
# Run configuration is read from a JSON file. (The more conventional
# alternative is a .env file loaded with python-dotenv's load_dotenv.)
import json  # imported here because the notebook's import cell does not provide it

config_file_path = "config.json"

# Read JSON data into a dictionary
with open(config_file_path, 'r', encoding='utf-8') as file:
    data = json.load(file)

IMAGE_DIR = data["image_dir"]  # prefix joined with each file name when images are opened
CSV_DIR = data["csv_dir"]      # handed to pd.read_csv by the dataset class
BATCH_SIZE = 64
EPOCHS = 3                     # iterations of the first training loop (frozen backbone)
FINE_TUNE_EPOCHS = 30          # upper bound on iterations of the fine-tuning loop
EARLY_STOP_EPOCHS = 3          # early-stop threshold used by the fine-tuning loop
save_model = True              # whether the fine-tuning loop writes checkpoints
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# Augmentation pipeline the dataset applies to every image it loads.
# NOTE: binding the result to `transforms` hides the `torchvision.transforms`
# module that the import cell bound to the same name.
# NOTE(review): the training/validation code divides by 255 and permutes from
# channels-last, which suggests ToDtype leaves the PIL input untouched here
# -- confirm against the torchvision v2 docs.
_augmentation_steps = [
    v2.Resize(size=(128, 128)),
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32),
    # ImageNet normalisation is currently disabled:
    # v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transforms = v2.Compose(_augmentation_steps)

In [None]:
class myImageDataset(Dataset):
    """Image-classification dataset indexed by a CSV file.

    The CSV must provide an ``Img_name`` column (file name, joined to the
    image directory with ``/``) and a ``label`` column.

    Args:
        csv_path: CSV file to read. Defaults to the module-level ``CSV_DIR``.
        image_dir: Directory holding the images. Defaults to the module-level
            ``IMAGE_DIR``, looked up on every item access.
        transform: Callable applied to each PIL image. Defaults to the
            module-level ``transforms`` pipeline, looked up on every item
            access so that re-running the pipeline cell takes effect.
    """

    def __init__(self, csv_path=None, image_dir=None, transform=None):
        self.df = pd.read_csv(CSV_DIR if csv_path is None else csv_path)
        # reset_index returns a new frame, so its result has to be kept.
        self.imgs = self.df[['Img_name']].reset_index(drop=True)
        self.labels = self.df[['label']].reset_index(drop=True)
        self._image_dir = image_dir
        self._transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is the transformed picture converted with ``np.array``;
        the label is the raw value from the ``label`` column.

        Raises:
            RuntimeError: if the file cannot be opened, decoded or
                transformed. The original error is chained as the cause.
        """
        img_name = self.imgs.iat[idx, 0]
        image_dir = IMAGE_DIR if self._image_dir is None else self._image_dir
        transform = transforms if self._transform is None else self._transform

        try:
            # The context manager closes the file handle; convert('RGB')
            # forces decoding and gives every sample three channels so
            # grayscale/RGBA files cannot break batching.
            with Image.open(f'{image_dir}/{img_name}') as rawimg:
                trans_image = transform(rawimg.convert('RGB'))
            numpyimage = np.array(trans_image)
        except Exception as err:
            # Returning None here (the previous behaviour) only postpones the
            # failure to the DataLoader's collate step, where the offending
            # file is no longer identifiable.
            raise RuntimeError(f"{img_name} is corrupted or unreadable") from err

        return numpyimage, self.labels.iat[idx, 0]


In [None]:
# Build the full dataset, then hold out 20% of it for validation.
dataset = myImageDataset()

total_size = len(dataset)
val_size = int(total_size * 0.2)
train_size = total_size - val_size

# NOTE: both subsets wrap the same dataset object, so whatever the dataset
# applies per item (including random augmentation) also applies to the
# validation samples.
train_dataset, val_dataset = random_split(dataset, lengths=[train_size, val_size])

# Only the training batches are reshuffled between epochs.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
validation_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
# ImageNet-pretrained ResNet-18. `weights=` replaces the deprecated
# `pretrained=True` flag; IMAGENET1K_V1 is the checkpoint that flag loaded.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features

# Freeze every pretrained parameter so the first training phase leaves the
# backbone untouched.
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier with a fresh 4-output layer. It is created after the
# freeze loop, so its parameters keep the default requires_grad=True and are
# the only ones trained until the network is unfrozen.
model.fc = nn.Linear(num_ftrs, 4)
model.to(device)

In [None]:
# Every parameter is registered, including the currently frozen ones: Adam
# skips parameters without a gradient, and the backbone starts updating as
# soon as requires_grad is switched back on.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# `last_epoch=-1` is the default, and `verbose` is a deprecated argument whose
# 'deprecated' sentinel is not meant to be passed by callers, so both are
# omitted.
# NOTE(review): gamma=0.01 divides the learning rate by 100 on every
# scheduler.step(); starting from 0.01 it is 1e-8 after three steps. Confirm
# this is intended (a value such as 0.9 is more typical).
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

### Training Loop

In [None]:
def train_one_epoch(epoch_index, tb_writer, log_every=100):
    """Run one optimisation pass over ``train_dataloader``.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn`` and
    ``device``. The caller is responsible for putting the model in training
    mode.

    Args:
        epoch_index: Zero-based epoch number; only used to compute the
            TensorBoard step for the logged points.
        tb_writer: Object with an ``add_scalar(tag, value, step)`` method.
        log_every: Number of batches averaged into each reported loss value.

    Returns:
        Mean per-batch loss of the most recently reported window. A trailing
        window shorter than ``log_every`` is reported too, so an epoch with
        fewer than ``log_every`` batches no longer yields 0.0. Returns 0.0
        only when the loader produced no batches.
    """
    batches_per_epoch = len(train_dataloader)
    running_loss = 0.
    last_loss = 0.
    window_batches = 0
    batch_no = 0

    def report(window_loss, window_size):
        # Print and log the mean loss of the window that just closed.
        mean_loss = window_loss / window_size
        print('  batch {} loss: {}'.format(batch_no, mean_loss))
        tb_x = epoch_index * batches_per_epoch + batch_no
        tb_writer.add_scalar('Loss/train', mean_loss, tb_x)
        return mean_loss

    for batch_no, (inputs, labels) in enumerate(train_dataloader, start=1):
        # Batches are treated as channels-last with 0-255 values: scale to
        # [0, 1], then reorder to the (N, C, H, W) layout the model is fed.
        inputs = inputs.float() / 255.0
        inputs = inputs.to(device).permute(0, 3, 1, 2)
        labels = labels.to(device)

        # Gradients accumulate by default, so clear them for every batch.
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        window_batches += 1
        if window_batches == log_every:
            last_loss = report(running_loss, window_batches)
            running_loss = 0.
            window_batches = 0

    # Flush the batches left over after the last full window.
    if window_batches:
        last_loss = report(running_loss, window_batches)

    return last_loss

In [None]:
def validate_one_epoch():
    """Evaluate the model on ``validation_dataloader`` and log the result.

    Uses the module-level ``model``, ``loss_fn``, ``device`` and ``writer``,
    and reads the globals ``avg_loss`` (latest training loss),
    ``epoch_number`` and ``f_epoch_number`` for reporting. Leaves the model
    in eval mode.

    Returns:
        ``(avg_vloss, v_accuracy)``: the mean per-batch validation loss as a
        Python float, and the accuracy in percent rounded to two decimals.
        With an empty loader the loss is ``inf`` and the accuracy ``0.0``.
    """
    v_correct = 0
    v_total = 0
    running_vloss = 0.0
    num_batches = 0
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vinputs, vlabels in validation_dataloader:
            # Same preprocessing as training: scale to [0, 1], channels first.
            vinputs = vinputs.float() / 255.0
            vinputs = vinputs.to(device).permute(0, 3, 1, 2)
            vlabels = vlabels.to(device)

            voutputs = model(vinputs)
            vpredictions = torch.argmax(voutputs, dim=1)

            # .item() keeps the accumulators as plain Python numbers.
            running_vloss += loss_fn(voutputs, vlabels).item()
            v_correct += (vpredictions == vlabels).sum().item()
            # Count real samples: the last batch may be smaller than
            # BATCH_SIZE, so batch-count arithmetic gives a wrong denominator.
            v_total += vlabels.size(0)
            num_batches += 1

    v_accuracy = round(v_correct / v_total * 100, 2) if v_total else 0.0
    print(f'{v_correct}/{v_total}')
    print(F'Val Accuracy: {v_accuracy}%')

    avg_vloss = running_vloss / num_batches if num_batches else float('inf')
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch for both training and
    # validation. The step adds both epoch counters so that points written
    # while only f_epoch_number advances do not all land on the same step.
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + f_epoch_number + 1)
    writer.flush()

    return avg_vloss, v_accuracy

### Training

In [None]:
# One TensorBoard run directory per launch, named after the start time.
timestamp = f'{datetime.now():%Y%m%d_%H%M%S}'
writer = SummaryWriter(f'runs/trainer_{timestamp}')

# Separate counters for the two training loops; both start at zero.
epoch_number = f_epoch_number = 0

In [None]:
# Lowest validation loss seen so far. This loop does not update it; it is
# initialised here so the value exists before the fine-tuning loop reads it.
best_vloss = 1000000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    # avg_loss must be bound before validating: validate_one_epoch reads it
    # (and epoch_number) as globals when it logs.
    avg_loss = train_one_epoch(epoch_number, writer)
    scheduler.step()

    # validate_one_epoch returns (loss, accuracy); unpack it so avg_vloss
    # holds the loss rather than the whole tuple.
    avg_vloss, v_accuracy = validate_one_epoch()

    epoch_number += 1

#### Unfreeze all layers

In [None]:
# Turn gradient tracking back on for every parameter so the whole network,
# not only the replaced classifier layer, is updated from here on.
for param in model.parameters():
    param.requires_grad_(True)

In [None]:
# Consecutive epochs that failed to produce a new best validation loss.
early_stop_counter = 0

for f_epoch in range(FINE_TUNE_EPOCHS):
    print('EPOCH {}:'.format(f_epoch_number + 1))

    # Epoch index counted across both training loops. Passing it to
    # train_one_epoch keeps this loop's TensorBoard steps from restarting at
    # zero and overlapping the points written by the first loop.
    global_epoch = epoch_number + f_epoch_number

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(global_epoch, writer)
    scheduler.step()

    avg_vloss, v_accuracy = validate_one_epoch()

    print(f'{avg_vloss} vs. {best_vloss}')
    model_path = 'model_{}_{}_{}'.format(timestamp, global_epoch, round(v_accuracy))

    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        # The patience counter resets on every improvement, whether or not
        # checkpoints are being written.
        early_stop_counter = 0
        if save_model:
            torch.save(model.state_dict(), model_path)
            torch.save(model, '/home/ubuntu/Dataset/'+ model_path)
    else:
        # A tie is not an improvement, so it also counts against patience.
        early_stop_counter += 1

    if early_stop_counter >= EARLY_STOP_EPOCHS:
        print("early stopping...")
        # Keep the final weights as well, but only when saving is enabled.
        if save_model:
            torch.save(model.state_dict(), model_path)
            torch.save(model, '/home/ubuntu/Dataset/full_models/'+ model_path)
        break

    f_epoch_number += 1