# Train Last Layer of Vision Transformer Classification Network from 'train' and 'test' Image Folders

### Import modules

In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
import os
import time
from datetime import datetime
import copy
from unidecode import unidecode

### Define training function

In [15]:
# Training loop
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, *,
                phases=None, loaders=None, target_device=None):
    """Train ``model`` and return it loaded with its best-scoring weights.

    Every epoch runs each phase in order. The first phase is the training
    phase (gradients enabled, optimizer and scheduler stepped); every other
    phase is evaluated without gradients. The weights that reach the highest
    accuracy on the second phase are restored before returning.

    Args:
        model: Network to optimize; it must already live on the device the
            batches are moved to.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            training phase.
        num_epochs: Number of passes over all phases.
        phases: Ordered phase names. Defaults to the notebook-level
            ``split_dirs``.
        loaders: Mapping from phase name to an iterable of
            ``(inputs, labels)`` batches. Defaults to the notebook-level
            ``dataloaders``.
        target_device: Device every batch is moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        The same ``model`` object, with the best weights loaded.

    Raises:
        ValueError: If a phase's loader yields no samples.
    """
    # Fall back to the notebook globals so existing positional calls keep
    # working unchanged; explicit arguments make the function reusable.
    if phases is None:
        phases = split_dirs
    if loaders is None:
        loaders = dataloaders
    if target_device is None:
        target_device = device

    train_phase = phases[0]
    selection_phase = phases[1]  # phase whose accuracy picks the best weights

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in phases:
            is_training = phase == train_phase
            model.train(is_training)  # train(False) is equivalent to eval()

            running_loss = 0.0
            running_corrects = 0
            samples_seen = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(target_device)
                labels = labels.to(target_device)

                # Gradients only exist (and only need clearing) when training.
                if is_training:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # The criterion returns a batch mean, so weight it by the
                # batch size to get a correct per-sample average later.
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                running_corrects += (preds == labels).sum().item()
                samples_seen += batch_size

            if is_training:
                scheduler.step()

            if samples_seen == 0:
                raise ValueError(f'No samples were produced for phase {phase!r}')

            # Average over the samples actually processed, so the metrics stay
            # correct even if the loader drops or resamples items.
            epoch_loss = running_loss / samples_seen
            epoch_acc = running_corrects / samples_seen

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Keep a private copy of the weights whenever accuracy improves.
            if phase == selection_phase and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model

### Define main parameters

In [27]:
# Set your data directory and class names
data_dir = r"C:\Users\luisr\Repositories\Code Projects\flood-vision\data\samples\1"
# split_dirs[0] is the split that gets trained on; split_dirs[1] is only evaluated.
split_dirs = ['train', 'test']
# NOTE(review): assumed to match the class sub-folders that ImageFolder discovers
# inside each split directory - confirm against the data on disk.
class_names = ['0', '1']
num_workers = 2

# Define data transformations
# Both splits share the same resize/crop/normalize pipeline; only the training
# split adds a random horizontal flip as augmentation.
data_transforms = {
    split_dirs[0]: transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    split_dirs[1]: transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
}

# Create datasets and data loaders
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in split_dirs}
dataloaders = {
    x: torch.utils.data.DataLoader(
        image_datasets[x],
        batch_size=4,
        # Only the training split benefits from a random order; evaluation
        # metrics do not depend on the order of the samples.
        shuffle=(x == split_dirs[0]),
        num_workers=num_workers,
    )
    for x in split_dirs
}
dataset_sizes = {x: len(image_datasets[x]) for x in split_dirs}

# Load the pretrained Vision Transformer and replace its classification head
# `weights=` replaces the deprecated `pretrained=True` flag and selects the same
# ImageNet-1k checkpoint.
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
for param in model.parameters(): # Freeze the whole pretrained backbone
    param.requires_grad = False
# A newly created layer has requires_grad=True by default, so the head is the
# only trainable part. Its size follows the model width and the class count
# instead of hard-coded numbers.
model.heads = nn.Linear(model.hidden_dim, len(class_names))

# Check device availability for CUDA (GPU) usage
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define the loss function, optimizer and learning-rate schedule
criterion = nn.CrossEntropyLoss()
# Only the new head is trainable, so only its parameters go to the optimizer.
optimizer = optim.SGD(model.heads.parameters(), lr=0.001)
# Multiply the learning rate by 0.1 every 7 scheduler steps.
step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)



### Train and save the best model

In [28]:
# Move the network onto the selected device before any batch is processed
model = model.to(device)

# Run the training loop; the returned model carries the best weights found
trained_model = train_model(
    model,
    criterion,
    optimizer,
    scheduler=step_lr_scheduler,
    num_epochs=10,
)

# Persist only the parameters (state dict), not the whole module object
torch.save(trained_model.state_dict(), 'trained_model.pth')

Epoch 1/10
----------
train Loss: 0.7307 Acc: 0.4800
test Loss: 0.7006 Acc: 0.5000

Epoch 2/10
----------
train Loss: 0.6430 Acc: 0.7200
test Loss: 0.6830 Acc: 0.5400

Epoch 3/10
----------
train Loss: 0.6015 Acc: 0.6800
test Loss: 0.6680 Acc: 0.5400

Epoch 4/10
----------
train Loss: 0.5564 Acc: 0.7800
test Loss: 0.6634 Acc: 0.6200

Epoch 5/10
----------
train Loss: 0.5126 Acc: 0.7800
test Loss: 0.6587 Acc: 0.6000

Epoch 6/10
----------
train Loss: 0.4764 Acc: 0.8600
test Loss: 0.6539 Acc: 0.6600

Epoch 7/10
----------
train Loss: 0.4515 Acc: 0.8400
test Loss: 0.6530 Acc: 0.6600

Epoch 8/10
----------
train Loss: 0.4191 Acc: 0.8800
test Loss: 0.6530 Acc: 0.6600

Epoch 9/10
----------
train Loss: 0.4151 Acc: 0.8600
test Loss: 0.6529 Acc: 0.6600

Epoch 10/10
----------
train Loss: 0.4165 Acc: 0.8800
test Loss: 0.6528 Acc: 0.6600

Training complete in 7m 19s
Best val Acc: 0.6600
