In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchsummary import summary
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
from tqdm import tqdm

from models import HMT
from plots import *
from FRDEEP import FRDEEPN, FRDEEPF

In [12]:
valid_size    = 110    # number of samples for validation
batch_size    = 16     # number of samples per mini-batch
num_classes   = 2      # The number of output classes. FRI/FRII
lr0           = torch.tensor(1e-2)  # The speed of convergence
momentum      = torch.tensor(9e-1)  # momentum for optimizer
num_batches   = 55     # multiplies up the total samples to ~30k like in paper
class_weights = torch.FloatTensor([0.6,0.4]) # for training
random_seed   = 42

In [13]:
transform = transforms.Compose([
#    transforms.CenterCrop(28),
    transforms.RandomRotation(0.,360.),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,))])

train_data = FRDEEPF('first', train=True, download=True, transform=transform)

num_train = len(train_data)
indices = list(range(num_train))
split = valid_size

np.random.seed(random_seed)
np.random.shuffle(indices)

train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=valid_sampler)

Files already downloaded and verified


In [14]:
test_data = FRDEEPF('first', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)

In [15]:
model = HMT()
learning_rate = lr0
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=1e-6)
criterion = nn.CrossEntropyLoss(weight=class_weights)

In [16]:
summary(model, (1, 150, 150))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 6, 150, 150]             732
       BatchNorm2d-2          [-1, 6, 150, 150]              12
         MaxPool2d-3            [-1, 6, 75, 75]               0
            Conv2d-4           [-1, 16, 75, 75]           2,416
       BatchNorm2d-5           [-1, 16, 75, 75]              32
         MaxPool2d-6           [-1, 16, 25, 25]               0
            Conv2d-7           [-1, 24, 25, 25]           3,480
       BatchNorm2d-8           [-1, 24, 25, 25]              48
            Conv2d-9           [-1, 24, 25, 25]           5,208
      BatchNorm2d-10           [-1, 24, 25, 25]              48
           Conv2d-11           [-1, 16, 25, 25]           3,472
      BatchNorm2d-12           [-1, 16, 25, 25]              32
        MaxPool2d-13             [-1, 16, 5, 5]               0
          Flatten-14                  [

(tensor(250234), tensor(250234))

In [19]:
epochs = 10
epoch_trainaccs, epoch_validaccs = [], []

classes = ('FRI', 'FRII')

for epoch in range(epochs):

    model.train()
    train_accs=[]; acc = 0
    for iter in range(num_batches):
        for batch, (x_train, y_train) in enumerate(train_loader):
            model.zero_grad()
            pred = model(x_train)
            loss = criterion(pred,y_train)
            loss.backward()
            optimizer.step()
            acc = (pred.argmax(dim=-1) == y_train).to(torch.float32).mean()
            train_accs.append(acc.mean().item())

    print('Epoch: {}, Loss: {}, Train Accuracy: {}'.format(epoch, loss, np.mean(train_accs)))


    with torch.no_grad():
        model.eval()
        valid_losses, valid_accs = [], []; acc = 0
        for iter in range(num_batches):
            for i, (x_val, y_val) in enumerate(valid_loader):
                valid_pred = model(x_val)
                loss = criterion(valid_pred,y_val)
                acc = (valid_pred.argmax(dim=-1) == y_val).to(torch.float32).mean()
                valid_losses.append(loss.item())
                valid_accs.append(acc.mean().item())

    print('Epoch: {}, Loss: {}, Validation Accuracy: {}'.format(epoch, np.mean(valid_losses), np.mean(valid_accs)))
    epoch_trainaccs.append(np.mean(train_accs))
    epoch_validaccs.append(np.mean(valid_accs))
    
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))
    
    class_correct = list(0. for i in range(2))
    class_total = list(0. for i in range(2))

    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    for i in range(len(classes)):
        print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
    
    print('#####################################################################')    
    torch.save(model,'model_1.out')
print("Final validation error: ",100.*(1 - epoch_validaccs[-1]))

#plot_error(epoch_trainaccs, epoch_validaccs)

Epoch: 0, Loss: 0.034753382205963135, Train Accuracy: 0.9989448051948052
Epoch: 0, Loss: 2.399550221678529, Validation Accuracy: 0.8181122442344566
Accuracy of the network on the test images: 94 %
Accuracy of   FRI : 95 %
Accuracy of  FRII : 92 %
#####################################################################
Epoch: 1, Loss: 8.701899787411094e-06, Train Accuracy: 0.9996753246753247
Epoch: 1, Loss: 3.238989293479795, Validation Accuracy: 0.8280612236493594
Accuracy of the network on the test images: 94 %
Accuracy of   FRI : 95 %
Accuracy of  FRII : 92 %
#####################################################################
Epoch: 2, Loss: 0.0010835780994966626, Train Accuracy: 0.9998376623376624
Epoch: 2, Loss: 3.044120230549578, Validation Accuracy: 0.845709646677042
Accuracy of the network on the test images: 94 %
Accuracy of   FRI : 95 %
Accuracy of  FRII : 92 %
#####################################################################
Epoch: 3, Loss: 2.0778017642442137e-05, Train Ac

KeyboardInterrupt: 