In [1]:
import json
import os
import time
from datetime import datetime

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

from data.image_dataset_rgb import ImageDataset
from models.vit import VisionTransformerModel
from training_vit.eval import evaluate_acc, evaluate_error
from training_vit.train import train
from utils.dataset_utils import check_disjoint


# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
# Configuration — every knob a reader might tune lives in this cell.
subject = 1              # zero-padded into the dataset folder name below
stride = 5               # zero-padded into the dataset folder name below
input_dims = (127, 127)  # selects the pre-cut dataset variant (folder name); images are resized to 224x224 later
mode = 'reg'  # 'reg' -> regression (MSELoss, 2 outputs); any other value -> classification (CrossEntropyLoss, 16 classes)
batch_size = 64
epochs = 200
lr = 1e-4                # Adam learning rate
seed = 0                 # global torch RNG seed: makes weight init and DataLoader shuffling reproducible

# Seed early so every stochastic step in later cells (model init, shuffling) is reproducible.
torch.manual_seed(seed)

In [3]:
# Dataset folder name, e.g. subject=1, stride=5, input_dims=(127, 127) -> '01_05_127_127'.
data_subset = f'{subject:>02}_{stride:>02}_{input_dims[0]}_{input_dims[1]}'
print(data_subset)


def split_paths(split):
    """Return ``(image_dir, annotations_csv)`` for one split of the selected data subset.

    Layout: ``data/real/<data_subset>/<split>/<data_subset>_<split>.csv``.
    """
    split_dir = f'data/real/{data_subset}/{split}'
    return split_dir, f'{split_dir}/{data_subset}_{split}.csv'


train_dir_path, train_annotations_file_path = split_paths('train')
val_dir_path, val_annotations_file_path = split_paths('val')
test_dir_path, test_annotations_file_path = split_paths('test')

# Resize every image to the 224x224 ViT input; input_dims above only picks which
# pre-cut dataset variant is loaded.
# NOTE(review): there is no Normalize step — confirm whether VisionTransformerModel expects normalized inputs.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
])

train_dataset = ImageDataset(
    annotations_file=train_annotations_file_path,
    img_dir=train_dir_path,
    transform=transform)
val_dataset = ImageDataset(
    annotations_file=val_annotations_file_path,
    img_dir=val_dir_path,
    transform=transform)
test_dataset = ImageDataset(
    annotations_file=test_annotations_file_path,
    img_dir=test_dir_path,
    transform=transform)

# Guard against leakage between EVERY pair of splits: train overlapping val or test
# inflates the reported metrics just as much as val/test overlap does.
# NOTE(review): the cost of check_disjoint on the train split is not visible here.
for pair_name, (split_a, split_b) in {
    'train/val': (train_dataset, val_dataset),
    'train/test': (train_dataset, test_dataset),
    'val/test': (val_dataset, test_dataset),
}.items():
    assert check_disjoint(split_a, split_b), f'{pair_name} splits overlap'

01_05_127_127


In [5]:


# Only the training data is shuffled; the validation/test loaders iterate in a fixed
# order so evaluation passes are deterministic and reproducible.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Sanity-check one training batch (images should come out 224x224 after the resize transform).
train_features, train_labels = next(iter(train_loader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Label batch shape: {train_labels.size()}")

Feature batch shape: torch.Size([64, 3, 224, 224])
Label batch shape: torch.Size([64, 2])


In [None]:
# Model setup
# 'reg' mode: the head emits 2 regression outputs; otherwise 16 class logits.
num_classes = 2 if mode == 'reg' else 16
criterion = nn.MSELoss() if mode == 'reg' else nn.CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to its device BEFORE building the optimizer (as the torch.optim docs
# advise), so the optimizer is constructed over the parameters that will actually train.
model = VisionTransformerModel(num_classes=num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
print(f'Using device: {device}')

# Train model
loss_history, train_perf_history, val_perf_history = train(
    model, device, mode, criterion, optimizer, train_loader, val_loader, epochs)

Using device: cpu
hi
yayyy
batch_X
batch_X
batch_X
batch_X
batch_X
batch_X
batch_X


In [None]:
# Training curves: loss on the left, train vs. validation performance on the right.
# NOTE(review): the x-axis assumes train() records one entry per epoch in each history,
# and that the perf metric is error in 'reg' mode / accuracy otherwise (as in the eval
# cell) — confirm in training_vit/train.py.
perf_label = 'Error' if mode == 'reg' else 'Accuracy'
fig, (loss_ax, perf_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(loss_history, label='Loss')
loss_ax.set(title='Training loss', xlabel='Epochs', ylabel='Loss')
loss_ax.legend()

perf_ax.plot(train_perf_history, label=f'Train {perf_label.lower()}')
perf_ax.plot(val_perf_history, label=f'Val {perf_label.lower()}')
perf_ax.set(title=f'Train vs. validation {perf_label.lower()}', xlabel='Epochs', ylabel=perf_label)
perf_ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# Evaluate on the held-out test split with the metric that matches the training mode.
# NOTE(review): assumes evaluate_error / evaluate_acc switch the model to eval() themselves — confirm.
if mode == 'reg':
    test_error = evaluate_error(model, device, test_loader)
    print(f'Test error: {test_error}')
    test_metric = {'test_error': test_error}
else:
    test_acc = evaluate_acc(model, device, test_loader)
    print(f'Test accuracy: {test_acc}')
    test_metric = {'test_acc': test_acc}

# Save model weights together with the settings needed to rebuild and interpret them:
# a bare state_dict does not record num_classes, mode, or which data subset it came from.
time_stamp = datetime.now().strftime('%m-%d--%H-%M-%S')
save_dir = f'models/saved_models/{time_stamp}'
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), f'{save_dir}/model.pt')

run_config = {
    'data_subset': data_subset,
    'mode': mode,
    'num_classes': num_classes,
    'batch_size': batch_size,
    'epochs': epochs,
    'lr': lr,
    **test_metric,
}
with open(f'{save_dir}/config.json', 'w') as config_file:
    # default=str: the metric's type is not visible here (it may be a tensor), so fall back to its repr.
    json.dump(run_config, config_file, indent=2, default=str)