In [20]:
import copy
import os

import numpy as np
import pandas as pd
import plotly.express as px
import torch

import curves, data, models
# Plot styling shared by every figure in this notebook.
px.defaults.template = 'plotly_white'

# Prefer the GPU when one is visible; checkpoints below are mapped onto this device.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

# Two checkpoints of the same polychain run:
#   checkpoint-0   -> `lin_checkpoint`, plotted below as the 'linear curve'
#   checkpoint-200 -> `checkpoint`,     plotted below as the 'polygonal curve'
# NOTE(review): absolute cluster paths; move to a configurable directory before sharing.
lin_path = '/scratch/fmager/fge/cifar10/vgg16/seed_1/polychain/checkpoint-0.pt'
path = '/scratch/fmager/fge/cifar10/vgg16/seed_1/polychain/checkpoint-200.pt'
lin_checkpoint, checkpoint = (
    torch.load(lin_path, map_location=device),
    torch.load(path, map_location=device),
)


Using device: cpu


In [21]:
# Arguments used to rebuild the data loaders and the curve model in the next cell.
# --- data ---
dataset, data_path, transform = 'CIFAR10', './data', 'VGG'
batch_size, num_workers = 128, 4
use_test = True  # evaluate on the test split (data.loaders prints a warning about this)
# --- model / curve ---
model = 'VGG16'      # attribute name looked up in `models`
curve = 'PolyChain'  # attribute name looked up in `curves`
num_bends = 3

In [22]:

# Build the data loaders and two copies of the curve model:
#   lin_curve_model <- checkpoint-0   (`lin_checkpoint`)
#   curve_model     <- checkpoint-200 (`checkpoint`)
loaders, num_classes = data.loaders(
    dataset,
    data_path,
    batch_size,
    num_workers,
    transform,
    use_test,
    shuffle_train=False,
)
architecture = getattr(models, model)
# Resolve the curve class into its own name and leave the configured string `curve`
# untouched. Overwriting `curve` with the class made this cell fail when re-run,
# because getattr(curves, <class>) raises TypeError.
curve_cls = getattr(curves, curve)
curve_model = curves.CurveNet(
    num_classes,
    curve_cls,
    architecture.curve,
    num_bends,
    architecture_kwargs=architecture.kwargs,
)
# Deep-copy before any state dict is loaded, so the two models share no parameters.
lin_curve_model = copy.deepcopy(curve_model).to(device)
lin_curve_model.load_state_dict(lin_checkpoint['model_state'])

curve_model = curve_model.to(device)
curve_model.load_state_dict(checkpoint['model_state'])

Files already downloaded and verified
You are going to run models on the test set. Are you sure?
Files already downloaded and verified


<All keys matched successfully>

In [23]:
# Evaluate the curve at T evenly spaced points t in [0, 1], both endpoints included.
T = 50
ts = np.linspace(start=0.0, stop=1.0, num=T)

In [None]:
# Evaluate curve: test accuracy (%) of both curve models at every t in `ts`.
# Switch to eval mode first. In train mode, any dropout / batch-norm layers in the
# architecture would perturb (and randomize) the measured accuracy.
lin_curve_model.eval()
curve_model.eval()

n_test = len(loaders['test'].dataset)
accuracy = []
accuracy_lin = []

with torch.no_grad():
    for i, t in enumerate(ts):
        correct_lin = 0
        correct = 0
        for inputs, target_batch in loaders['test']:
            # Follow `device` instead of hard-coding .cuda(), which crashes on CPU-only hosts.
            inputs = inputs.to(device, non_blocking=True)
            target_batch = target_batch.to(device, non_blocking=True)

            pred_lin = lin_curve_model(inputs, t).argmax(dim=1, keepdim=True)
            pred = curve_model(inputs, t).argmax(dim=1, keepdim=True)

            correct_lin += pred_lin.eq(target_batch.view_as(pred_lin)).sum().item()
            correct += pred.eq(target_batch.view_as(pred)).sum().item()

        accuracy_lin.append(correct_lin * 100.0 / n_test)
        accuracy.append(correct * 100.0 / n_test)

        print(f'{i+1}/{T} - Lin Accuracy: {accuracy_lin[-1]:.2f}, Accuracy: {accuracy[-1]:.2f}')


In [None]:
# Plot test accuracy along t for the checkpoint-0 ('linear curve') and the
# checkpoint-200 ('polygonal curve') models, then save the figure as PNG.
accuracy_df = pd.DataFrame({
    't': ts,
    'linear curve': accuracy_lin,
    'polygonal curve': accuracy,
})
fig = px.line(
    accuracy_df,
    x='t',
    y=['linear curve', 'polygonal curve'],
    title='Curve accuracy',
    # Wide-form px.line names the melted columns 'value' / 'variable'; relabel them.
    labels={'value': 'Test accuracy (%)', 'variable': 'curve'},
)
fig.show()
# write_image does not create missing directories (it also requires the `kaleido` package).
os.makedirs('figures', exist_ok=True)
fig.write_image('figures/accuracy_curve.png')

In [24]:
import torch
from torch.utils.data import Dataset

class Latent(Dataset):
    """Map-style dataset that pairs stored features `x` with labels `y` by index.

    Args:
        x: indexable container of samples (e.g. a tensor whose first axis is the sample axis).
        y: indexable container of labels, aligned with `x` by position.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        """Number of samples, taken from `x`."""
        n_items = len(self.x)
        return n_items

    def __getitem__(self, idx):
        """Return the `(sample, label)` pair at position `idx`."""
        sample, label = self.x[idx], self.y[idx]
        return sample, label

In [None]:
# Show the wrapped network's layer names; the hook cell below attaches to one of them (fc2).
print(f'{curve_model.net}')

In [29]:
# Record, for every t in `ts`, the output of `curve_model.net.fc2` on the whole test set,
# and stack the results along a new last axis (shape printed below: N_test x features x T).
# NOTE(review): this hook captures fc2's raw output. Whether an activation sits between
# fc2 and fc3 is defined in `models`, not here; confirm this is the intended latent.

def forward_hook(module, args, output):
    """Append a detached CPU copy of the hooked layer's output to the global `outputs` list."""
    # Keep latents on CPU so they match `targets` (collected on CPU) and do not
    # accumulate len(ts) copies of the test-set features in GPU memory.
    outputs.append(output.detach().cpu())

# Remove a hook left over from a previous run of this cell, so activations are not recorded twice.
if 'handle' in globals():
    handle.remove()

curve_model.eval()  # no dropout / batch-norm training behaviour while extracting features

handle = curve_model.net.fc2.register_forward_hook(forward_hook)

latent_spaces = []
targets = []
try:
    with torch.no_grad():
        for i, t in enumerate(ts):
            outputs = []
            for inputs, target in loaders['test']:
                # Labels are recorded on the first pass only. This assumes the test loader
                # yields samples in the same order on every pass (i.e. it is not shuffled).
                if i == 0:
                    targets.append(target.detach())
                _ = curve_model(inputs.to(device, non_blocking=True), t)
            latent_spaces.append(torch.cat(outputs).unsqueeze(-1))
finally:
    # Always detach the hook, even if the loop fails part-way.
    handle.remove()

targets = torch.cat(targets)
latents = torch.cat(latent_spaces, dim=-1)
assert latents.shape[0] == targets.shape[0], 'expected exactly one latent per test label'
print(targets.shape)
print(latents.shape)



torch.Size([10000])
torch.Size([10000, 512, 50])


In [None]:
# init and save the dataset
datast = Latent(latents, targets)
torch.save(datast, f'./data/{dataset}_{model}_latent_space.pt')
