In [None]:
# import definitions of classes and functions for learning by confusion
from lbc_utils import *
import pickle
from IPython.display import clear_output
import random
import torch
from torchvision import transforms
import matplotlib.pyplot as plt

In [None]:
# set folders of training set and evaluation set (relative paths, resolved against the
# current working directory). TRAIN_FOLDER feeds the dataset used for training, EVAL_FOLDER
# the dataset on which the confusion signal is evaluated.
TRAIN_FOLDER = 'pictures/tech_train'
EVAL_FOLDER = 'pictures/tech_eval'

In [None]:
# define the training and evaluation (confusion) loops.
# the only nonstandard lines are the definitions of Y and Y_bool, which convert the correct label y
# to a vector. that vector has one entry per grid separator (number of categories minus one), and each
# entry corresponds to a different left-right splitting of the dataset.

def train_loop(dataloader, model, loss_fn, optimizer, device='cuda', subset=None):
    '''
    run one pass over dataloader, taking one optimizer step per batch.

    dataloader: any iterable yielding (X, y) batches of inputs and labels
    model: callable mapping a batch X to predictions (logits)
    loss_fn: callable taking (pred, Y) and returning a single-element loss tensor
    optimizer: optimizer that updates the trained parameters
    device: device the batches are moved to before the forward pass
    subset: forwarded unchanged to lbc_label

    returns a list with the loss (python float) of every record_every-th batch,
    starting with the first one. record_every is read from the module-level
    namespace and has to be defined before this function is called.
    the training / eval mode of the model is not touched here.
    '''
    losses = []
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        # convert the labels y to the float target vector Y that is compared with pred
        Y = lbc_label(y, subset).float()
        loss = loss_fn(pred, Y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # keep only every record_every-th loss, stored as a python float
        if batch % record_every == 0:
            losses.append(loss.item())
    return losses

def confusion_loop(dataloader, model, loss_fn, n_categories, device='cuda', subset=None):
    '''
    evaluate the confusion signal and the loss of model on dataloader.

    returns running_conf, which is the error p^{err} from Eq. (2) of the article
    multiplied by the number of samples per gridpoint. also returns the loss
    (sum of the per-batch losses, not normalized).

    dataloader: iterable yielding (X, y) batches of inputs and labels
    model: callable mapping a batch X to logits
    loss_fn: callable taking (pred, Y) and returning a single-element loss tensor
    n_categories: running_conf has n_categories - 1 entries; also passed to confusion_weight
    device: device used for the batches and the accumulators
    subset: forwarded unchanged to confusion_weight and lbc_label

    the model is put into eval mode for the duration of the call (so that layers such as
    batch norm neither use batch statistics nor update their running statistics with
    evaluation data) and its previous mode is restored before returning.
    '''
    torch_weight = confusion_weight(n_categories, subset, device=device).view(1, -1)
    running_conf = torch.zeros(n_categories-1, device=device)
    running_loss = 0

    # remember the current mode; plain callables without a `training` flag are left alone
    was_training = getattr(model, 'training', False)
    if was_training:
        model.eval()

    try:
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)

                # threshold the logits at probability 0.5
                pred_bool = torch.sigmoid(pred) > 0.5
                Y_bool = lbc_label(y, subset)
                Y = Y_bool.float()

                # misclassified entries with label 1 are weighted by 1 / (1 - w),
                # misclassified entries with label 0 by 1 / w; then sum over the batch
                confusion = (
                    1. / (1. - torch_weight) * (pred_bool != Y_bool) * (Y_bool == 1) +
                    1. / (torch_weight) * (pred_bool != Y_bool) * (Y_bool == 0)
                ).sum(0)

                running_conf += confusion
                loss = loss_fn(pred, Y)
                running_loss += loss.item()
    finally:
        # hand the model back in the mode it arrived in, even if evaluation failed
        if was_training:
            model.train()

    return 0.5 * running_conf, running_loss


In [None]:
import os
from PIL import Image
import re

class CustomDataset(torch.utils.data.Dataset):
    '''
    image dataset whose labels are encoded in the file names.

    every regular file in directory is one sample. the file name has to contain
    'technology<year>_'; the label is year - BASE_YEAR as an int64 tensor.

    directory: folder containing the image files
    transform: optional callable applied to the loaded RGB image
    '''

    # offset subtracted from the year found in the file name
    BASE_YEAR = 1900

    def __init__(self, directory, transform=None):
        self.directory = directory
        self.transform = transform
        self.pattern = r'technology(-?\d+)_'
        # os.listdir returns entries in arbitrary order and includes sub-directories
        # (e.g. .ipynb_checkpoints): keep regular files only, in a reproducible order.
        self.file_list = sorted(
            f for f in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, f))
        )

    def __len__(self):
        return len(self.file_list)

    def filename_to_year(self, filename, pattern=None):  # helper function
        '''
        return the year encoded in filename minus BASE_YEAR.
        pattern defaults to self.pattern; its first group has to capture the year.
        raises ValueError if the pattern is not found in filename.
        '''
        if pattern is None:
            pattern = self.pattern
        match = re.search(pattern, filename)
        if match is None:
            raise ValueError(f"could not get year from {filename}")
        return int(match.group(1)) - self.BASE_YEAR

    def __getitem__(self, index):
        filename = self.file_list[index]

        # load image. convert() reads the pixel data, so the file can be closed afterwards.
        img_path = os.path.join(self.directory, filename)
        with open(img_path, 'rb') as f:
            img = Image.open(f).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # load label and convert to tensor
        year = self.filename_to_year(filename, self.pattern)
        target = torch.tensor(year, dtype=torch.int64)

        return img, target

In [None]:
def mymodel(n_categories=130):
    '''
    load ResNet-50 and replace the final layer for LBC loss.
    n_categories: number of categories / grid points
    (implies number of grid separators = n_categories - 1)

    returns the model with a fresh linear final layer that has n_categories - 1 outputs.
    `models` is not imported in this notebook explicitly; presumably it is
    torchvision.models provided by the star import from lbc_utils (confirm).
    '''
    # Load pretrained model. newer torchvision releases deprecate pretrained=True in favour
    # of weights=...; IMAGENET1K_V1 is the weight set that pretrained=True refers to.
    # fall back to the legacy argument when the weights enum does not exist.
    weights_enum = getattr(models, 'ResNet50_Weights', None)
    if weights_enum is not None:
        model = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet50(pretrained=True)

    # Replace the final layer: one output (logit) per grid separator
    num_ftrs = model.fc.in_features
    n_separators = n_categories - 1
    model.fc = torch.nn.Linear(num_ftrs, n_separators)

    return model


In [None]:
# ImageNet parameters. preprocess image sizes for faster training.
# every image is resized to 224x224, converted to a tensor and normalized per RGB channel
# with the mean / std values below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# dataset reads TRAIN_FOLDER (training), dataset2 reads EVAL_FOLDER (confusion signal);
# both share the same transform.
dataset = CustomDataset(directory=TRAIN_FOLDER, transform=transform)
dataset2 = CustomDataset(directory=EVAL_FOLDER, transform=transform)

# tip: set num_workers and persistent_workers for faster dataloaders (neither is set here).
# both loaders use batches of 32 samples and shuffle the data.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
dataloader2 = torch.utils.data.DataLoader(dataset2, batch_size=32, shuffle=True)

# NOTE(review): presumably every label (year - 1900, see CustomDataset) lies in [0, 150) - verify against the data
n_categories_dataset = 150  # number of categories in dataset

In [None]:
# training configuration, followed by the training / evaluation (confusion) loop

# use the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

ds_size = len(dataset2)  # number of evaluation samples; normalizes the confusion signal below
batch_size = 16*64  # NOTE: this is not the batch size of the dataloaders (see dataloader.batch_size)

# define training parameters
learning_rate = 0.5 * 1e-4
epochs = 151
record_every = 10  # read by train_loop: the loss of every record_every-th batch is recorded
subset = list(range(n_categories_dataset - 1))  # all n_categories_dataset - 1 indices

n_categories_total = n_categories_dataset
n_categories = len(subset) + 1

# initialize model
model = mymodel(n_categories).to(device)
criterion = LBCWithLogitsLoss(n_categories_dataset, subset, device=device)

# define the optimizer. only the replaced final layer (model.fc) is optimized.
# NOTE(review): gradients are still computed for the rest of the network; if it is never
# fine-tuned later, setting requires_grad=False on those parameters would save time.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=learning_rate)

# training and evaluation loop. the confusion signal is evaluated on dataloader2
# (the EVAL_FOLDER data) before every training epoch.
losses = []
errs = []
valid_losses = []

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")

    # validation loop including confusion signal
    conf, valid_loss = confusion_loop(
        dataloader2, model, criterion, n_categories_dataset, device=device, subset=subset
    )
    valid_losses.append(valid_loss)
    err = conf.detach().cpu().numpy() / ds_size
    errs.append(err)

    # plotting
    plt.semilogy(subset, err,'-d')
    plt.xlabel('system parameter')
    plt.ylabel('error')
    plt.show()

    plt.plot(valid_losses)
    plt.title('vlosses')
    plt.show()

    # training loop
    loss = train_loop(dataloader, model, criterion, optimizer, device=device, subset=subset)
    losses += loss

    # plot train loss. one x-axis step corresponds to record_every batches of the training
    # dataloader, so the unit is computed from the dataloader's actual batch size.
    plt.semilogy(losses)
    plt.xlabel(f'seen samples [{record_every * dataloader.batch_size}]')
    plt.ylabel('loss')
    plt.show()

# evaluation of results from final loop
conf, valid_loss = confusion_loop(
    dataloader2, model, criterion, n_categories_dataset, device=device, subset=subset
)
valid_losses.append(valid_loss)
err = conf.detach().cpu().numpy() / ds_size
errs.append(err)

# collect results in one dict
results = {'errs': errs, 'losses': losses, 'valid_losses': valid_losses}