In [None]:
import timm
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets


In [None]:
def create_vit(model_name, num_classes=4):
    """Build a randomly initialised Vision Transformer classifier via timm.

    Args:
        model_name: Size key, one of 'tiny', 'small' or 'base'. Each key maps
            to the timm architecture ``vit_<size>_patch16_224``.
        num_classes: Size of the classification head passed to timm.
            Defaults to 4, the value that used to be hard-coded here.

    Returns:
        The object returned by ``timm.create_model`` with ``pretrained=False``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported keys.
    """
    architectures = {
        'tiny': 'vit_tiny_patch16_224',
        'small': 'vit_small_patch16_224',
        'base': 'vit_base_patch16_224',
    }

    # Validate before touching timm: a chain of independent `if` statements
    # would leave the result unbound for an unknown key and surface as a
    # confusing UnboundLocalError at the return statement.
    if model_name not in architectures:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {sorted(architectures)}"
        )

    return timm.create_model(
        architectures[model_name], pretrained=False, num_classes=num_classes
    )

In [None]:
# Single place to edit when the data lives elsewhere: the directory that holds
# the 'Train' and 'Validation' sub-folders.
DATA_DIR = '/Users/aravdhoot/Remote-PD-Detection/split_energy_images'
BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)  # same 224 as in the vit_*_patch16_224 architecture names

# Resize every image to a fixed size, then convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

train_dataset = datasets.ImageFolder(root=f'{DATA_DIR}/Train', transform=transform)
val_dataset = datasets.ImageFolder(root=f'{DATA_DIR}/Validation', transform=transform)

# Labels are assigned per split from the folder names found on disk. If the
# two splits do not contain the same class folders, the integer labels would
# disagree and validation metrics would be silently wrong, so fail early.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError(
        "Train and Validation splits define different classes: "
        f"{train_dataset.class_to_idx} vs {val_dataset.class_to_idx}"
    )

# Shuffle only the training split; validation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Instantiate the smallest of the three ViT variants offered by create_vit.
model = create_vit(model_name='tiny')

# Dump the module tree so the architecture can be inspected by eye.
print(model)

In [None]:
# Cross-entropy objective applied to the model outputs and integer labels.
criterion = torch.nn.CrossEntropyLoss()

# Adam over every model parameter with a step size of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
def _run_epoch(model, loader, criterion, optimizer=None, log_batches=False):
    """Make one pass over ``loader`` and report loss and accuracy.

    Args:
        model: Network to run. It is switched to train mode when an optimizer
            is supplied and to eval mode otherwise.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``. It is assumed
            to return the mean loss over the batch, which is the default
            reduction of ``torch.nn.CrossEntropyLoss()``.
        optimizer: When given, gradients are computed and a step is taken per
            batch. When ``None`` the pass runs with gradients disabled.
        log_batches: If true, print the batch loss and running accuracy after
            every batch.

    Returns:
        ``(mean_loss, accuracy_percent)``, both averaged over samples. Both
        are NaN if the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    # Keep the batches on whatever device the model currently lives on; this
    # is a no-op when everything is on the CPU.
    device = next(model.parameters()).device

    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(training):
        for batch_idx, (images, labels) in enumerate(loader):
            images = images.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # Weight each batch mean by its size so that a short final batch
            # does not count as much as a full one in the epoch average.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size

            if log_batches:
                print(f"Batch {batch_idx+1}, Loss: {loss.item():.6f}, Accuracy: {100 * correct / seen:.2f}%")

    if seen == 0:
        # An empty loader would otherwise raise ZeroDivisionError.
        return float('nan'), float('nan')

    return loss_sum / seen, 100 * correct / seen


num_epochs = 10  # Set the number of epochs

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}")

    # Training pass: updates the weights and logs every batch.
    train_loss, train_accuracy = _run_epoch(
        model, train_loader, criterion, optimizer, log_batches=True
    )
    print(f"Training - Epoch {epoch+1}, Loss: {train_loss:.6f}, Accuracy: {train_accuracy:.2f}%")

    # Validation pass: no optimizer, so no gradients and no weight updates.
    val_loss, val_accuracy = _run_epoch(model, val_loader, criterion)
    print(f"Validation - Epoch {epoch+1}, Loss: {val_loss:.6f}, Accuracy: {val_accuracy:.2f}%")
