In [2]:
import sys
import logging
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
import os

def main():
    dtype = torch.cuda.FloatTensor

    # Training data paths
    train_data_dir = '/home/marafath/scratch/projection_2d/train'
    
    # Validation data paths
    val_data_dir = '/home/marafath/scratch/projection_2d/val'

    # Define transforms
    train_transform = transforms.Compose([
        transforms.Scale(256),
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),            
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
      ])
    
    val_transforms = transforms.Compose([
        transforms.Scale(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),            
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # create a training data loader
    train_ds = ImageFolder(train_data_dir, transform=train_transform)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available())
    
    # create a validation data loader
    val_ds = ImageFolder(val_data_dir, transform=train_transform)
    val_loader = DataLoader(val_ds, batch_size=16, num_workers=2, pin_memory=torch.cuda.is_available())
    
    # Create ResNet50, CrossEntropyLoss and Adam optimizer
    device = torch.device('cuda:0')
    model = models.resnet50(pretrained=False)
    # model = models.densenet121(pretrained=True)
    
    num_classes = len(train_ds.classes)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # model.fc = torch.nn.Linear(model.classifier.in_features, num_classes) # For densenet
    
    model.type(dtype)
    loss_function = torch.nn.CrossEntropyLoss().type(dtype)
    
    for param in model.parameters():
        param.requires_grad = True
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    
    model.load_state_dict(torch.load('/home/marafath/scratch/saved_models/best_metric_model_r50.pth'))

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    epc = 300 # Number of epoch
    for epoch in range(epc):
        print('-' * 10)
        print('epoch {}/{}'.format(epoch + 1, epc))
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            #print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item()))
            writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                num_correct = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    val_outputs = model(val_images)
                    value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                    metric_count += len(value)
                    num_correct += value.sum().item()
                metric = num_correct / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), '/home/marafath/scratch/saved_models/best_metric_model_r50.pth')
                    print('saved new best metric model')
                print('current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}'.format(
                    epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar('val_accuracy', metric, epoch + 1)
    print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))
    writer.close()

if __name__ == '__main__':
    main()



----------
epoch 1/300
epoch 1 average loss: 0.5781
saved new best metric model
current epoch: 1 current accuracy: 0.5050 best accuracy: 0.5050 at epoch 1
----------
epoch 2/300
epoch 2 average loss: 0.5948
saved new best metric model
current epoch: 2 current accuracy: 0.5150 best accuracy: 0.5150 at epoch 2
----------
epoch 3/300
epoch 3 average loss: 0.5983
saved new best metric model
current epoch: 3 current accuracy: 0.5950 best accuracy: 0.5950 at epoch 3
----------
epoch 4/300
epoch 4 average loss: 0.5660
current epoch: 4 current accuracy: 0.5450 best accuracy: 0.5950 at epoch 3
----------
epoch 5/300
epoch 5 average loss: 0.5725
current epoch: 5 current accuracy: 0.5350 best accuracy: 0.5950 at epoch 3
----------
epoch 6/300
epoch 6 average loss: 0.5513
current epoch: 6 current accuracy: 0.5400 best accuracy: 0.5950 at epoch 3
----------
epoch 7/300
epoch 7 average loss: 0.5816
current epoch: 7 current accuracy: 0.5650 best accuracy: 0.5950 at epoch 3
----------
epoch 8/300
epoc

epoch 63 average loss: 0.4529
current epoch: 63 current accuracy: 0.5000 best accuracy: 0.6200 at epoch 16
----------
epoch 64/300
epoch 64 average loss: 0.4274
current epoch: 64 current accuracy: 0.5400 best accuracy: 0.6200 at epoch 16
----------
epoch 65/300
epoch 65 average loss: 0.4394
current epoch: 65 current accuracy: 0.5250 best accuracy: 0.6200 at epoch 16
----------
epoch 66/300
epoch 66 average loss: 0.4317
current epoch: 66 current accuracy: 0.5300 best accuracy: 0.6200 at epoch 16
----------
epoch 67/300
epoch 67 average loss: 0.4510
current epoch: 67 current accuracy: 0.5900 best accuracy: 0.6200 at epoch 16
----------
epoch 68/300
epoch 68 average loss: 0.4325
current epoch: 68 current accuracy: 0.5350 best accuracy: 0.6200 at epoch 16
----------
epoch 69/300
epoch 69 average loss: 0.4241
current epoch: 69 current accuracy: 0.5700 best accuracy: 0.6200 at epoch 16
----------
epoch 70/300
epoch 70 average loss: 0.4044
current epoch: 70 current accuracy: 0.4850 best accur

epoch 125 average loss: 0.3469
current epoch: 125 current accuracy: 0.4900 best accuracy: 0.6200 at epoch 16
----------
epoch 126/300
epoch 126 average loss: 0.3443
current epoch: 126 current accuracy: 0.4850 best accuracy: 0.6200 at epoch 16
----------
epoch 127/300
epoch 127 average loss: 0.3312
current epoch: 127 current accuracy: 0.5050 best accuracy: 0.6200 at epoch 16
----------
epoch 128/300
epoch 128 average loss: 0.3227
current epoch: 128 current accuracy: 0.5100 best accuracy: 0.6200 at epoch 16
----------
epoch 129/300
epoch 129 average loss: 0.3089
current epoch: 129 current accuracy: 0.5800 best accuracy: 0.6200 at epoch 16
----------
epoch 130/300
epoch 130 average loss: 0.3625
current epoch: 130 current accuracy: 0.5300 best accuracy: 0.6200 at epoch 16
----------
epoch 131/300
epoch 131 average loss: 0.3882
current epoch: 131 current accuracy: 0.4650 best accuracy: 0.6200 at epoch 16
----------
epoch 132/300
epoch 132 average loss: 0.3584
current epoch: 132 current accu

current epoch: 186 current accuracy: 0.5100 best accuracy: 0.6200 at epoch 16
----------
epoch 187/300
epoch 187 average loss: 0.2533
current epoch: 187 current accuracy: 0.5000 best accuracy: 0.6200 at epoch 16
----------
epoch 188/300
epoch 188 average loss: 0.2595
current epoch: 188 current accuracy: 0.5300 best accuracy: 0.6200 at epoch 16
----------
epoch 189/300
epoch 189 average loss: 0.2700
current epoch: 189 current accuracy: 0.5050 best accuracy: 0.6200 at epoch 16
----------
epoch 190/300
epoch 190 average loss: 0.2836
current epoch: 190 current accuracy: 0.5300 best accuracy: 0.6200 at epoch 16
----------
epoch 191/300
epoch 191 average loss: 0.2525
current epoch: 191 current accuracy: 0.5000 best accuracy: 0.6200 at epoch 16
----------
epoch 192/300
epoch 192 average loss: 0.2689
current epoch: 192 current accuracy: 0.5200 best accuracy: 0.6200 at epoch 16
----------
epoch 193/300
epoch 193 average loss: 0.2364
current epoch: 193 current accuracy: 0.5050 best accuracy: 0.6

epoch 248 average loss: 0.2346
current epoch: 248 current accuracy: 0.4950 best accuracy: 0.6200 at epoch 16
----------
epoch 249/300
epoch 249 average loss: 0.2104
current epoch: 249 current accuracy: 0.5300 best accuracy: 0.6200 at epoch 16
----------
epoch 250/300
epoch 250 average loss: 0.1744
current epoch: 250 current accuracy: 0.5000 best accuracy: 0.6200 at epoch 16
----------
epoch 251/300
epoch 251 average loss: 0.2075
current epoch: 251 current accuracy: 0.5300 best accuracy: 0.6200 at epoch 16
----------
epoch 252/300
epoch 252 average loss: 0.2194
current epoch: 252 current accuracy: 0.5350 best accuracy: 0.6200 at epoch 16
----------
epoch 253/300
epoch 253 average loss: 0.2152
current epoch: 253 current accuracy: 0.5100 best accuracy: 0.6200 at epoch 16
----------
epoch 254/300
epoch 254 average loss: 0.2857
current epoch: 254 current accuracy: 0.5500 best accuracy: 0.6200 at epoch 16
----------
epoch 255/300
epoch 255 average loss: 0.2702
current epoch: 255 current accu