In [1]:
import torchvision.models as models
import torch
from torchvision.transforms import v2
import torchvision.transforms as transforms
import os
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset
import torchvision.datasets as datasets
import numpy as np
import torch.optim as optim
import torch.nn as nn
import random
import numpy as np
from sklearn.metrics import f1_score, balanced_accuracy_score, roc_auc_score

**Augumentation**

In [2]:
minority_augmentation = transforms.Compose([
    
    transforms.Resize((224, 224)),

    transforms.RandomRotation(degrees=(0, 10)), 

    transforms.RandomApply([transforms.RandomHorizontalFlip()], p=0.5),

    transforms.RandomApply([transforms.RandomVerticalFlip()], p=0.5),

    transforms.ToTensor()
    
])


majority_augmentation = transforms.Compose([

    transforms.Resize((224, 224)),
    
    transforms.ToTensor()
])

In [3]:
class ResNetTransformerClassifier(nn.Module):
    
    def __init__(self, num_classes=4, d_model=512, num_heads=8, num_layers=1):
        
        super(ResNetTransformerClassifier, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        self.resnet.fc = nn.Identity()  
        self.resnet.fc = nn.Linear(2048, d_model)  
        
        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Classification Layer
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        
        x = self.resnet(x)  
        x = x.unsqueeze(1)  
        
        x = self.transformer(x)  # Pass through transformer
        x = x.squeeze(1)  # Remove sequence dimension
        
        return self.fc(x) 

In [4]:
dataset_path = '/kaggle/input/trining-dataset/training'
validationDataset_path = '/kaggle/input/validation/validating'

training_dataset = datasets.ImageFolder(root=dataset_path)
validation_dataset = datasets.ImageFolder(root=validationDataset_path)

**Loading the sample**

In [5]:
targets = np.array([label for _, label in training_dataset.samples])

class_counts = np.bincount(targets)

class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)


In [6]:
targets, class_counts, class_weights

(array([0, 0, 0, ..., 3, 3, 3]),
 array([ 1154,  2694, 28663,  1162]),
 tensor([8.6655e-04, 3.7120e-04, 3.4888e-05, 8.6059e-04]))

In [7]:
minority_classes = [0, 1, 3]

**loading the dataset**

In [8]:
dataset_path = "./trining-dataset/training"  

class TrainingDataset(Dataset):
    
    def __init__(self, dataset, minority_classes):
        self.dataset = dataset
        self.minority_classes = set(minority_classes) 

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image_path, label = self.dataset.samples[idx]
        image = self.dataset.loader(image_path) 

        if label in self.minority_classes:
            image = minority_augmentation(image)
        else:
            image = majority_augmentation(image)
        
        return image, label

In [9]:
sample_weights = [class_weights[label] for label in targets]

sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

In [10]:
train_dataset = TrainingDataset(training_dataset, minority_classes)

In [11]:
train_loader = DataLoader(train_dataset, batch_size=16, sampler=sampler)

In [12]:
class ValidationDataset(Dataset):
    
    def __init__(self, dataset, minority_classes):
        self.dataset = dataset
        self.minority_classes = set(minority_classes) 

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image_path, label = self.dataset.samples[idx]
        image = self.dataset.loader(image_path) 

        if label in self.minority_classes:
            image = minority_augmentation(image)
        else:
            image = majority_augmentation(image)
        
        return image, label

**loading validation dataset**

In [13]:
validation_dataset2 = ValidationDataset(validation_dataset,  minority_classes)

In [14]:
valid_loader = DataLoader( validation_dataset2, batch_size=128,  shuffle=False,
                        
    drop_last=False  )

In [15]:
model = ResNetTransformerClassifier()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 173MB/s] 


In [16]:
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-1)

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

**testing certieria**

In [17]:

class_weights = class_weights / class_weights.sum()  # Normalize to sum to 1

class_weights = class_weights.to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights).to(device)

In [18]:
num_epochs = 50

patience = 40
best_metric = -np.inf  

epochs_no_improve = 0

In [19]:
def train_statistics():
    
    all_preds, all_labels, all_probs = [], [], []
    
    with torch.no_grad():
        
        for images, labels in train_loader:
            
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs)

   
    balanced_acc = balanced_accuracy_score(all_labels, all_preds)
    macro_f1 = f1_score(all_labels, all_preds, average="macro")
    mean_auc = roc_auc_score(all_labels, all_probs, multi_class="ovr", average="macro")
    
    print(f" for train: Epoch  Balanced Acc: {balanced_acc:.4f}, Macro F1: {macro_f1:.4f}, Mean AUC: {mean_auc:.4f}")

In [None]:
last = False
for epoch in range(num_epochs):
    
    model.train()
    running_loss = 0.0
    
    for images, labels in train_loader:
        
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)

        probs = torch.softmax(outputs, dim=1)
        loss = criterion(probs, labels).to(device)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()

    scheduler.step()

    model.eval()
    train_statistics()
    all_preds, all_labels, all_probs = [], [], []

    with torch.no_grad():
        
        for images, labels in valid_loader:
            
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs)

   
    balanced_acc = balanced_accuracy_score(all_labels, all_preds)
    macro_f1 = f1_score(all_labels, all_preds, average="macro")
    mean_auc = roc_auc_score(all_labels, all_probs, multi_class="ovr", average="macro")
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, "
          f"Balanced Acc: {balanced_acc:.4f}, Macro F1: {macro_f1:.4f}, Mean AUC: {mean_auc:.4f}")

  
    if mean_auc > best_metric:
        
        best_metric = mean_auc
        epochs_no_improve = 0
        torch.save(model.state_dict(), "best_model.pth")  
        last = True
        
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

if last == False:
    torch.save(model.state_dict(), "best_model.pth")

 for train: Epoch  Balanced Acc: 0.2500, Macro F1: 0.1014, Mean AUC: 0.5081
Epoch [1/50], Loss: 1.3404, Balanced Acc: 0.2500, Macro F1: 0.0166, Mean AUC: 0.5696
Epoch [1/50], Loss: 1.3404
 for train: Epoch  Balanced Acc: 0.2500, Macro F1: 0.0995, Mean AUC: 0.5256
Epoch [2/50], Loss: 1.3020, Balanced Acc: 0.2500, Macro F1: 0.0167, Mean AUC: 0.4722
Epoch [2/50], Loss: 1.3020
 for train: Epoch  Balanced Acc: 0.2500, Macro F1: 0.0997, Mean AUC: 0.6453
Epoch [3/50], Loss: 1.2702, Balanced Acc: 0.2500, Macro F1: 0.0167, Mean AUC: 0.4752
Epoch [3/50], Loss: 1.2702
 for train: Epoch  Balanced Acc: 0.2500, Macro F1: 0.0997, Mean AUC: 0.5050
Epoch [4/50], Loss: 1.2691, Balanced Acc: 0.2500, Macro F1: 0.0167, Mean AUC: 0.5076
Epoch [4/50], Loss: 1.2691
 for train: Epoch  Balanced Acc: 0.2500, Macro F1: 0.1001, Mean AUC: 0.4907
Epoch [5/50], Loss: 1.2678, Balanced Acc: 0.2500, Macro F1: 0.0167, Mean AUC: 0.6598
Epoch [5/50], Loss: 1.2678
 for train: Epoch  Balanced Acc: 0.2500, Macro F1: 0.0999, M