In [1]:
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Preprocessing applied to every image, in order:
#   1. resize so the shorter side is 256 px,
#   2. take the central 224x224 crop,
#   3. convert the PIL image to a float tensor,
#   4. normalise each channel with the given mean / std
#      (presumably the usual ImageNet statistics; confirm against the model).
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

In [4]:
def load(path, transform=None, batch_size=50, train_fraction=0.7, seed=None):
    """Build train / test DataLoaders from an ``ImageFolder`` directory.

    The samples found under ``path`` are randomly partitioned into a training
    subset and a held-out subset, and each subset is wrapped in a DataLoader.

    Args:
        path: Root directory handed to ``ImageFolder``.
        transform: Optional transform forwarded to ``ImageFolder``.
        batch_size: Samples per batch for both loaders (default 50).
        train_fraction: Fraction of the samples assigned to the training
            subset; the remainder forms the test subset (default 0.7).
        seed: Optional integer seed for the random split. When ``None`` the
            split draws from PyTorch's global RNG, so it differs between runs
            unless that RNG has been seeded elsewhere.

    Returns:
        A ``(train_loader, test_loader)`` tuple.

    Raises:
        ValueError: If ``train_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(
            f"train_fraction must be in the open interval (0, 1), got {train_fraction!r}"
        )

    dataset = ImageFolder(root=path, transform=transform)

    train_size = int(train_fraction * len(dataset))
    test_size = len(dataset) - train_size

    # Only pass a generator when a seed was requested, so the default
    # behaviour (global RNG) is unchanged for existing callers.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)

    train_set, test_set = torch.utils.data.random_split(
        dataset, [train_size, test_size], **split_kwargs
    )

    # Reshuffle the training subset every epoch; a fixed batch order across
    # epochs is undesirable for SGD-style training.
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=0,
        shuffle=True,
    )

    # Evaluation order does not affect the metrics, so keep it deterministic.
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        num_workers=0,
        shuffle=False,
    )

    return train_loader, test_loader

# Split the image folder into training and held-out loaders, applying the
# shared preprocessing pipeline to every image.
train_loader, test_loader = load(
    path='data/imagenet-mini/train/',
    transform=transform,
)

In [5]:
model_path = 'models/model_MLP_scripted.pt'

# Deserialize the TorchScript archive with its tensors mapped onto `device`,
# then move the resulting module to `device` as well.
model = torch.jit.load(model_path, map_location=device)
model = model.to(device)

In [6]:
# Loss for multi-class classification on raw logits.
criterion = nn.CrossEntropyLoss()

# SGD with momentum and L2 weight decay over all model parameters.
sgd_settings = dict(lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer = optim.SGD(model.parameters(), **sgd_settings)

# Multiply the learning rate by `gamma` every `step_size` calls to
# scheduler.step(); with step_size=30 the first decay happens only after
# 30 such calls.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

In [7]:
# Number of full passes over the training loader made by the training loop.
num_epochs = 5

In [8]:
# Empty per-epoch metrics table: one row per epoch with its accuracy and
# validation loss.
log_columns = ['epoch', 'accuracy', 'val_loss']
log_df = pd.DataFrame(columns=log_columns)

In [9]:
# Train for `num_epochs` epochs, evaluating on the held-out loader after each
# one. Per-epoch metrics are collected as plain dicts and converted to a
# DataFrame once at the end: DataFrame.append was removed in pandas 2.0, and
# growing a frame row by row is quadratic anyway.
epoch_records = []

for epoch in tqdm(range(num_epochs), desc="Total Training Progress"):
    # ---- training pass ----
    model.train()

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # One scheduler step per epoch.
    scheduler.step()

    # ---- evaluation pass on the held-out split ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        # The held-out loader is `test_loader`; no `val_loader` is defined
        # anywhere in this notebook, so the previous name raised NameError.
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            # Predicted class = index of the largest logit along dim 1.
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    accuracy = 100 * correct / total
    # Mean of the per-batch mean losses.
    avg_val_loss = val_loss / len(test_loader)
    print(f'Epoch {epoch+1}/{num_epochs}, Accuracy: {accuracy:.2f}%, Validation Loss: {avg_val_loss:.4f}')

    epoch_records.append({'epoch': epoch+1, 'accuracy': accuracy, 'val_loss': avg_val_loss})

# Build the metrics table in one go and persist it.
log_df = pd.DataFrame(epoch_records, columns=['epoch', 'accuracy', 'val_loss'])
log_df.to_csv('training_log.csv', index=False)


# Export the trained model as TorchScript.
# NOTE(review): `model` was itself loaded with torch.jit.load, so scripting it
# again is presumably a no-op; kept so `scripted_trained_model` stays defined.
scripted_trained_model = torch.jit.script(model)
scripted_trained_model.save("trained_model_scripted_CNN.pt")

Total Training Progress:   0%|          | 0/5 [00:00<?, ?it/s]
Epoch 1/5:   0%|          | 0/487 [00:00<?, ?it/s][A
Epoch 1/5:   0%|          | 1/487 [00:14<1:54:18, 14.11s/it][A
Epoch 1/5:   0%|          | 2/487 [00:27<1:51:03, 13.74s/it][A
Epoch 1/5:   1%|          | 3/487 [00:40<1:47:59, 13.39s/it][A
Epoch 1/5:   1%|          | 4/487 [00:53<1:46:38, 13.25s/it][A
Epoch 1/5:   1%|          | 5/487 [01:06<1:45:47, 13.17s/it][A
Epoch 1/5:   1%|          | 6/487 [01:19<1:45:15, 13.13s/it][A
Epoch 1/5:   1%|▏         | 7/487 [01:32<1:44:48, 13.10s/it][A
Epoch 1/5:   2%|▏         | 8/487 [01:45<1:44:32, 13.09s/it][A
Epoch 1/5:   2%|▏         | 9/487 [01:58<1:44:13, 13.08s/it][A
Epoch 1/5:   2%|▏         | 10/487 [02:11<1:43:57, 13.08s/it][A
Epoch 1/5:   2%|▏         | 11/487 [02:25<1:43:55, 13.10s/it][A
Epoch 1/5:   2%|▏         | 12/487 [02:38<1:43:36, 13.09s/it][A
Epoch 1/5:   3%|▎         | 13/487 [02:51<1:43:08, 13.06s/it][A
Epoch 1/5:   3%|▎         | 14/487 [03:04<1:43:

KeyboardInterrupt: 