In [2]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from data import SPMDataset
from models import HeightPrediction
import contextlib

from pathlib import Path

## Data Loading

In [3]:
# Machine-specific absolute path to the HDF5 file with the STM scans — adjust
# for your environment.
hdf5_path = '/l/dsh_homo.hdf5'
BATCH_SIZE = 30


def make_loader(mode, shuffle):
    """Build a DataLoader over random-height STM slices for one dataset split.

    Args:
        mode: split name forwarded to SPMDataset (this notebook uses 'train' and 'val').
        shuffle: whether the DataLoader reshuffles sample order every epoch.
    """
    dataset = SPMDataset(hdf5_path=hdf5_path,
                         mode=mode,
                         scan='stm',
                         height='random')
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle)


# Only the training split needs shuffling; evaluation iterates the whole split,
# so a fixed order is sufficient. NOTE(review): height='random' may still
# randomize samples inside SPMDataset itself — not visible from here.
train_loader = make_loader('train', shuffle=True)
val_loader = make_loader('val', shuffle=False)

# Sanity-check the shapes of one training batch.
X, h = next(iter(train_loader))

print("X.shape: ", X.shape) # X: [N, C, nx, ny] Random slice from the STM scan
print("h.shape", h.shape)   # h: [N]            Height of the random slice

X.shape:  torch.Size([30, 1, 128, 128])
h.shape torch.Size([30])


## Model Architecture

In [4]:
import torch
import torch.nn as nn
        
# CREATE YOUR MODEL HERE
# gets as input an STM image [N, C, 128, 128]
# returns height prediction  [N, 1]

class HeightPrediction(nn.Module):
    """LeNet-style CNN that regresses one height value per STM image.

    Two (5x5 valid conv -> ReLU -> 2x2 max-pool) stages are followed by three
    fully connected layers ending in a single output unit.

    Input:  [N, in_channels, input_size, input_size]
    Output: [N, 1]

    Args:
        in_channels: number of channels of the input image (1 for the STM
            slices loaded in this notebook).
        input_size: side length of the square input image. The default of 128
            matches the 128x128 slices loaded in this notebook.

    Raises:
        ValueError: if `input_size` is too small to survive both conv/pool stages.
    """

    def __init__(self, in_channels=1, input_size=128):
        super().__init__()
        # Registration order is kept identical to the original so that seeded
        # weight initialization is unchanged.
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        # Derive the flattened feature size instead of hard-coding 16*29*29,
        # which is only valid for 128x128 inputs (128 -> 62 -> 29).
        side = self._feature_side(input_size)
        if side < 1:
            raise ValueError(
                f"input_size={input_size} is too small for two conv(5)/pool(2) stages")
        self.fc1 = nn.Linear(16 * side * side, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    @staticmethod
    def _feature_side(size):
        """Spatial side length after applying conv(k=5) then pool(2) twice."""
        for _ in range(2):
            size = (size - 4) // 2  # 5x5 valid conv trims 4, then 2x2 pool halves (floor)
        return size

    def forward(self, x):
        """Map images [N, C, H, W] to height predictions [N, 1]."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

net = HeightPrediction()

## Training Loop

In [5]:
from models import HeightPrediction

# Regression objective: mean squared error between predicted and true heights.
mse_loss = nn.MSELoss()
# Adam over all parameters of `net`, learning rate 1e-3.
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)

In [11]:
def train(dataloader, model, loss, optimizer, log_every=400):
    """Run one training epoch of `model` over `dataloader`.

    Args:
        dataloader: iterable of (X, h) batches, with h holding one target
            height per sample (shape [N]); must support len() for logging.
        model: network mapping X to predictions of shape [N, 1].
        loss: loss function, called as loss(predictions, targets).
        optimizer: optimizer over the model's parameters; stepped once per batch.
        log_every: print the current batch loss every `log_every` batches.

    Returns:
        Mean of the per-batch loss values over the epoch, or nan if
        `dataloader` yields no batches.
    """
    # A preceding evaluation (e.g. compute_accuracy) leaves the model in eval mode.
    model.train()
    # Keep the public parameter name `loss`, but don't shadow it with the
    # per-batch loss tensor below.
    loss_fn = loss
    loss_sum, n_batches = 0.0, 0
    for batch, (X, h) in enumerate(dataloader):
        target = h.unsqueeze(1).float()  # [N] -> [N, 1] to match the model output

        optimizer.zero_grad()
        outputs = model(X)
        batch_loss = loss_fn(outputs, target)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        loss_sum += loss_value
        n_batches += 1
        if batch % log_every == 0:
            print(f"Loss:{loss_value:>5f} [{batch + 1:>4d}/{len(dataloader)}]")
    return loss_sum / n_batches if n_batches else float("nan")

In [None]:
#train(train_loader, net, mse_loss, optimizer) #Check the loss

In [8]:
#-------------------- This function computes the accuracy on the test dataset

def compute_accuracy(testloader, net, tol=0.5):
    """Fraction of samples whose predicted height lies within `tol` of the target.

    The model is a regressor with a single output column. Taking
    ``torch.max(outputs, 1)`` over an [N, 1] output always yields index 0, so a
    classification-style comparison is meaningless here. Instead, a prediction
    counts as correct when ``|prediction - h| <= tol``.

    Args:
        testloader: iterable of (X, h) batches, with h of shape [N].
        net: model returning one prediction per sample (e.g. shape [N, 1]).
        tol: absolute tolerance in the units of h. The default 0.5 roughly
            means "rounds to the correct value" if h holds integer slice
            indices — TODO confirm the units of the height labels.

    Returns:
        The accuracy as a float in [0, 1] (nan if `testloader` is empty).
        The percentage is also printed.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for X, h in testloader:
                predicted = net(X).reshape(h.shape)  # [N, 1] -> [N]
                total += h.size(0)
                correct += ((predicted - h).abs() <= tol).sum().item()
    finally:
        net.train(was_training)  # don't leave the model stuck in eval mode
    accuracy = correct / total if total else float("nan")
    print(f"Accuracy: {accuracy * 100:>0.1f}%")
    return accuracy