In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
import os
import torchvision
from torchvision import models
import timm
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time
from tqdm import tqdm

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Global-average-pools each channel to a scalar ("squeeze"), maps the
    resulting (batch, channels) vector through a bottleneck MLP
    (channels -> channels // reduction -> channels, no biases) ending in a
    sigmoid ("excitation"), and multiplies the input by those per-channel
    gates.

    Args:
        in_channels: number of channels of the 4-D input.
        reduction: bottleneck reduction ratio of the gating MLP.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        hidden = in_channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Unpacking four dims enforces a (batch, channels, height, width) input.
        batch, channels, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)

class SEBottleneck(nn.Module):
    """ResNet bottleneck residual block with a Squeeze-and-Excitation gate.

    Main path: 1x1 conv (in_planes -> planes), 3x3 conv (planes -> planes,
    carrying ``stride``), 1x1 conv (planes -> planes * expansion), each
    followed by BatchNorm, with ReLU after the first two. The expanded output
    is reweighted by an SEBlock, passed through dropout, added to the shortcut
    and finally ReLU'd.

    Args:
        in_planes: input channel count.
        planes: bottleneck width; the block outputs ``planes * expansion`` channels.
        stride: stride of the 3x3 conv.
        downsample: optional module applied to the input to form the shortcut;
            needed whenever the main-path output shape differs from the input's,
            because the two are added element-wise.
        reduction: SE reduction ratio.
        dropout_prob: dropout probability applied to the residual branch.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None, reduction=16, dropout_prob=0.5):
        super().__init__()
        out_planes = planes * self.expansion
        # Submodules are created in the original order, which fixes both the
        # parameter-initialization RNG sequence and the state_dict key order.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.se = SEBlock(out_planes, reduction)
        self.downsample = downsample
        self.stride = stride
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        # 1x1 reduce -> 3x3 (possibly strided) -> 1x1 expand.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Channel recalibration, then dropout on the residual branch only.
        out = self.dropout(self.se(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)

class SEResNet50(nn.Module):
    """SE-ResNet-50 classifier built on an ImageNet-pretrained torchvision ResNet-50.

    The stem (conv1 / bn1 / relu / maxpool) is reused directly from the
    pretrained ResNet-50. The four residual stages are SEBottleneck stacks
    (3, 4, 6, 3 blocks). When ``load_pretrained_blocks`` is True, the pretrained
    conv/BN weights of the corresponding ResNet-50 stages are copied into them,
    so that only the SE gates and the new classification head start from random
    initialization. Previously the stages were always random, which meant the
    "frozen" layer1 below was frozen at random weights.

    The head is avgpool -> dropout -> Linear(2048, 512) -> ReLU -> dropout ->
    Linear(512, num_classes).

    Args:
        num_classes: number of output logits.
        dropout_prob: dropout probability used in the classification head
            (the SEBottleneck blocks use their own default dropout).
        load_pretrained_blocks: copy pretrained ResNet-50 stage weights into
            the SE stages (default True).
    """

    def __init__(self, num_classes=1000, dropout_prob=0.5, load_pretrained_blocks=True):
        super(SEResNet50, self).__init__()
        self.num_classes = num_classes

        self.inplanes = 64
        resnet = models.resnet50(weights='IMAGENET1K_V2')
        # Stem modules are shared with the pretrained network.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        self.layer1 = self._make_layer(SEBottleneck, 64, 3)
        self.layer2 = self._make_layer(SEBottleneck, 128, 4, stride=2)
        self.layer3 = self._make_layer(SEBottleneck, 256, 6, stride=2)
        self.layer4 = self._make_layer(SEBottleneck, 512, 3, stride=2)

        if load_pretrained_blocks:
            self._load_pretrained_stages(resnet)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * SEBottleneck.expansion, 512)
        self.fc2 = nn.Linear(512, num_classes)

        self.dropout = nn.Dropout(p=dropout_prob)

        # Freeze initial layers. NOTE: requires_grad=False does not stop the
        # BatchNorm layers from updating their running statistics in train() mode.
        for param in self.conv1.parameters():
            param.requires_grad = False
        for param in self.bn1.parameters():
            param.requires_grad = False
        for param in self.layer1.parameters():
            param.requires_grad = False

    def _load_pretrained_stages(self, resnet):
        """Copy pretrained conv/BN weights from ``resnet``'s stages into the SE stages.

        SEBottleneck names its layers conv1/bn1/conv2/bn2/conv3/bn3/downsample,
        which is expected to match torchvision's Bottleneck. Only the SE-gate
        parameters (keys containing ``.se.``) may be absent from the pretrained
        state. Any other missing or unexpected key raises, and so does a shape
        mismatch (raised by ``load_state_dict``).
        """
        for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
            stage = getattr(self, stage_name)
            pretrained_state = getattr(resnet, stage_name).state_dict()
            result = stage.load_state_dict(pretrained_state, strict=False)
            unexpected = list(result.unexpected_keys)
            missing = [key for key in result.missing_keys if '.se.' not in key]
            if unexpected or missing:
                raise RuntimeError(
                    f"{stage_name}: pretrained ResNet-50 weights do not match the SE stage "
                    f"(missing={missing}, unexpected={unexpected})"
                )

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    def _make_layer(self, block, planes, blocks, stride=1, reduction=16):
        """Build one stage of ``blocks`` residual blocks; updates ``self.inplanes``.

        Only the first block may change stride/channel count, so only it gets a
        1x1-conv + BN ``downsample`` shortcut (when needed).
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, reduction))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, reduction=reduction))

        return nn.Sequential(*layers)

In [None]:
class ViT(nn.Module):
    """timm ``vit_base_patch16_224`` classifier with an ImageNet-pretrained backbone.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=1000):
        super(ViT, self).__init__()
        self.model_name = 'vit_base_patch16_224'
        self.num_classes = num_classes
        self.model = timm.create_model(self.model_name, pretrained=True, num_classes=num_classes)

        # create_model already sizes the head for num_classes. This replaces it
        # with a freshly (default-)initialized nn.Linear of the same shape.
        self.model.head = nn.Linear(self.model.head.in_features, self.num_classes)

    def forward(self, x, return_intermediates=False):
        """Return class logits of shape (batch, num_classes).

        The previous implementation returned ``forward_intermediates(x)``
        directly. That is a ``(features, intermediates)`` tuple whose features
        have not gone through the classification head, so it could not be fed
        to a classification loss or argmax.

        Args:
            x: input image batch.
            return_intermediates: if True, also return timm's per-block
                intermediate features as ``(logits, intermediates)``.
        """
        if not return_intermediates:
            return self.model(x)
        features, intermediates = self.model.forward_intermediates(x)
        # forward_intermediates stops before pooling/head; apply them explicitly.
        logits = self.model.forward_head(features)
        return logits, intermediates

In [None]:
def get_data_loaders(data_dir, batch_size=32,
                     resize=(256, 256), crop=(224, 224),
                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
                     num_workers=4):
    """Build train/val/test DataLoaders from an ImageFolder-style directory tree.

    Expects ``data_dir/train``, ``data_dir/val`` and ``data_dir/test``, each
    laid out as torchvision ``ImageFolder`` expects (one sub-directory per class).

    Args:
        data_dir: dataset root directory.
        batch_size: batch size used by all three loaders.
        resize: argument to ``transforms.Resize``. A ``(h, w)`` tuple resizes
            to exactly that size; an int matches the shorter image side to it.
        crop: crop size (random crop for train, center crop for val/test).
        mean, std: per-channel normalization statistics (read, never mutated).
        num_workers: worker processes per loader (previously hard-coded to 4).

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the train loader shuffles.
    """
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')
    test_dir = os.path.join(data_dir, 'test')

    train_transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.RandomCrop(crop),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # Validation and test share the same deterministic preprocessing.
    eval_transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.CenterCrop(crop),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
    val_dataset = datasets.ImageFolder(val_dir, transform=eval_transform)
    test_dataset = datasets.ImageFolder(test_dir, transform=eval_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

In [None]:
def evaluate_model(model, data_loader, criterion=None, device='cuda'):
    """Compute classification accuracy (and optionally mean loss) over a loader.

    Switches ``model`` to eval mode and runs it without gradient tracking. The
    model is left in eval mode afterwards.

    Args:
        model: module mapping an input batch to (batch, num_classes) scores.
        data_loader: iterable of ``(inputs, labels)`` batches.
        criterion: optional loss. When given, the per-sample mean loss is also returned.
        device: device each batch is moved to.

    Returns:
        ``accuracy`` if ``criterion`` is None, otherwise ``(avg_loss, accuracy)``.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            if criterion is not None:
                # The criterion returns a batch mean. Re-weighting by batch size
                # keeps the final average per-sample even with a short last batch.
                running_loss += criterion(outputs, labels).item() * inputs.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("evaluate_model: data_loader yielded no samples")

    accuracy = correct / total
    if criterion is not None:
        # Divide by samples actually evaluated, not len(dataset), so a sampler
        # that visits only part of the dataset is still averaged correctly.
        avg_loss = running_loss / total
        return avg_loss, accuracy
    return accuracy

In [None]:
def save_model(model, file_name, dir_name="trained_models/"):
    """Save ``model.state_dict()`` to ``dir_name/file_name`` and return that path.

    The target directory (including parents) is created if it does not exist.
    ``exist_ok=True`` avoids the check-then-create race of the previous version.

    Args:
        model: module whose state dict is saved.
        file_name: file name inside ``dir_name``.
        dir_name: output directory (defaults to the previously hard-coded
            ``"trained_models/"``). An empty string means the current directory.

    Returns:
        The path the state dict was written to.
    """
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)

    file_path = os.path.join(dir_name, file_name)
    torch.save(model.state_dict(), file_path)
    return file_path

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler_gamma, scheduler_step_size, 
                num_epochs, device, patience, file_name, restore_best_weights=True):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    Each epoch does one optimization pass over ``train_loader``. It then measures
    accuracy on ``train_loader`` (in eval mode, with the loader's own transforms)
    and loss/accuracy on ``val_loader`` via ``evaluate_model``. Whenever
    validation accuracy improves, the weights are written with
    ``save_model(model, file_name)``. Training stops early once validation loss
    has not improved for ``patience`` consecutive epochs.

    Args:
        model: module to train (modified in place).
        train_loader, val_loader: iterables of ``(inputs, labels)`` batches.
        criterion: loss function applied to ``model(inputs)`` and labels.
        optimizer: optimizer over the model parameters.
        scheduler_gamma, scheduler_step_size: StepLR parameters. If either is
            None, no LR scheduler is used.
        num_epochs: maximum number of epochs.
        device: device batches are moved to.
        patience: epochs without validation-loss improvement before stopping.
        file_name: checkpoint file name passed to ``save_model``.
        restore_best_weights: if True (default), load the weights from the
            epoch with the best validation accuracy back into ``model`` before
            returning, so the returned model matches the saved checkpoint.
            Previously the last-epoch weights were always returned.

    Returns:
        The trained ``model`` (same object).
    """
    start_time = time.time()

    best_val_loss = float('inf')
    best_val_acc = 0.0
    best_state = None
    counter = 0

    scheduler = None
    if scheduler_step_size is not None and scheduler_gamma is not None:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step_size, gamma=scheduler_gamma)

    for epoch in range(num_epochs):
        print("\n", '-'*10)
        if scheduler is not None:
            current_lr = scheduler.get_last_lr()[0]
        else:
            current_lr = optimizer.param_groups[0]['lr']
        print(f'Epoch {epoch+1}/{num_epochs}, Current Learning Rate: {current_lr}')
        e_start = time.time()

        # Training phase
        model.train()
        running_loss = 0.0
        seen = 0
        # 1-based epoch label, consistent with the other per-epoch prints.
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} - Training"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            seen += inputs.size(0)

        epoch_loss = running_loss / max(seen, 1)
        # NOTE: this is a second full pass over the training set, in eval mode
        # but with the training-time augmentations of train_loader.
        train_acc = evaluate_model(model, train_loader, None, device)

        # Validation phase
        val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)

        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}')
        print(f'Train Accuracy: {train_acc:.4f}, Val Accuracy: {val_acc:.4f}')
        print('Epoch time: ', time.time() - e_start)

        # Check for best validation accuracy and save model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_model(model, file_name)
            if restore_best_weights:
                # Detached clones, so later optimizer steps do not alter the snapshot.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            print(f'Best model saved with val accuracy: {best_val_acc:.4f}')

        # Check for best validation loss and early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print(f'Validation loss did not improve for {patience} epochs. Early stopping...')
                break

        if scheduler is not None:
            scheduler.step()

    if restore_best_weights and best_state is not None:
        model.load_state_dict(best_state)

    print('Training complete. Best validation accuracy: {:.4f}'.format(best_val_acc))
    print('Total time: ', time.time() - start_time)

    return model

In [None]:
model_name = 'vit' # 'vit' or 'resnet' or 'efficientnet'
dataset = 'cub' # 'cub' or 'cars' or 'aircrafts'
data_dir = 'datasets/' + dataset
batch_size = 32
num_epochs = 20
dropout_rate = 0.5  # head dropout; only passed to SEResNet50 ('resnet')
learning_rate = 0.001
momentum = 0.9  # NOTE(review): not consumed by the 'adam' optimizer
weight_decay = 0.001
criteria = 'cross_entropy'
optimz = 'adam' # 'adam' or 'sgd'
scheduler_step_size = 5  # StepLR step; set this or scheduler_gamma to None to disable the scheduler
scheduler_gamma = 0.1
patience = 5  # early stopping: epochs without validation-loss improvement
resize = 256  # int: transforms.Resize matches the shorter image side to this length
crop_size = 224
mean = [0.485, 0.456, 0.406]  # ImageNet normalization statistics
std = [0.229, 0.224, 0.225]
seed = 42
# Seed torch's RNG (weight init, dropout, shuffling, torch-based augmentations)
# so that Restart & Run All is reproducible.
torch.manual_seed(seed)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Count class sub-directories only, as torchvision's ImageFolder does.
# os.listdir would also count stray files (e.g. '.DS_Store') and mis-size the head.
with os.scandir(os.path.join(data_dir, 'train')) as entries:
    num_classes = sum(1 for entry in entries if entry.is_dir())

In [None]:
if model_name == 'resnet':
    model = SEResNet50(num_classes=num_classes, dropout_prob=dropout_rate).to(device)
elif model_name == 'vit':
    model = ViT(num_classes=num_classes).to(device)
elif model_name == 'efficientnet':
    model = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)
    # Freeze the first feature stage.
    for param in model.features[0].parameters():
        param.requires_grad = False
    # Replace the classifier's Linear (index 1) with one sized for this dataset.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)
    model = model.to(device)
else:
    # Fail fast; otherwise `model` is silently left undefined (or stale from an earlier run).
    raise ValueError(f"Model not found: {model_name!r}")

In [None]:
# Only cross-entropy is supported; reject anything else up front.
if criteria != 'cross_entropy':
    raise ValueError("Criterion not found")
criterion = nn.CrossEntropyLoss()

In [None]:
if optimz == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
elif optimz == 'sgd':
    # The config advertises 'sgd' and defines `momentum`; previously this raised.
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
else:
    raise ValueError("Optimizer not found")
file_name = f'{model_name}_{dataset}_{num_epochs}e_bs{batch_size}_lr{learning_rate}_dr{dropout_rate}_c{criteria}_o{optimz}_sg{scheduler_gamma}_sss{scheduler_step_size}.pth'

In [None]:
train_loader, val_loader, test_loader = get_data_loaders(
    data_dir,
    batch_size=batch_size,
    resize=resize,
    crop=crop_size,
    mean=mean,
    std=std,
)

In [None]:
# train_model takes 11 parameters; the extra `model_name` positional argument
# previously passed here raised a TypeError.
model = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler_gamma, scheduler_step_size,
                    num_epochs, device, patience, file_name)

In [None]:
accuracy = evaluate_model(model, test_loader, criterion=None, device=device)
print(f'Test Accuracy: {accuracy:.4f}')