In [None]:
import time

import torch
import torchvision
from torch import nn, optim

from sairg_utils import launch, get_data_loaders, train, test

In [None]:
# num_epochs = 10
# tik = time.time()
# launch(num_epochs=num_epochs, num_proc=2)
# tok = time.time()
# tok - tik

In [None]:
# Pretrained ResNet-18 (torchvision's DEFAULT weights); the weights object also
# supplies the matching preprocessing transforms used when loading the data.
base_net_weights = torchvision.models.ResNet18_Weights.DEFAULT
base_net = torchvision.models.resnet18(weights=base_net_weights)

# Probe only top-level modules whose name contains 'layer' (in torchvision's
# ResNet these are the residual stages); dotted names are nested sub-modules.
# Alternative probe set kept for reference:
#   [k for k, v in base_net.named_modules() if 'bn' in k] + ['fc']
module_list = [
    mod_name
    for mod_name, _ in base_net.named_modules()
    if '.' not in mod_name and 'layer' in mod_name
]
print(len(module_list))
module_list

In [None]:
# base_net_weights = torchvision.models.ViT_B_32_Weights.DEFAULT
# base_net = torchvision.models.vit_b_32(weights=torchvision.models.ViT_B_32_Weights.DEFAULT)
# [k for k, v in base_net.named_modules() if 'encoder_layer' in k and len(k.split('.')) < 4]

In [None]:
# Latest forward output of each hooked module, keyed by module name.
activations = {}


def get_activations(name):
    """Build a forward hook that records a module's output in ``activations``.

    Args:
        name: key under which the module's output is stored.

    Returns:
        A hook with the ``(module, inputs, output)`` signature expected by
        ``nn.Module.register_forward_hook``.
    """
    def hook(module, inputs, output):
        # Detached so stored activations do not keep the autograd graph alive.
        activations[name] = output.detach()
    return hook


# Re-running this cell would otherwise stack another hook on every module;
# remove the hooks registered by a previous run first.
for _old_handle in globals().get("activation_hook_handles", []):
    _old_handle.remove()
activation_hook_handles = [
    module.register_forward_hook(get_activations(name))
    for name, module in base_net.named_modules()
    if name in module_list
]

In [None]:
# CIFAR-10, preprocessed with the transforms that ship with the backbone weights.
transform = base_net_weights.transforms()
dataset = torchvision.datasets.CIFAR10
train_loader, test_loader = get_data_loaders(dataset, transform)

# iter(train_loader).next()

In [None]:
# Fine-tune the backbone with the project's `train` helper and time it.
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model before constructing its optimizer, as the torch.optim docs advise.
base_net.to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(base_net.parameters(), lr=1e-3)

# perf_counter is the monotonic clock intended for measuring durations.
tik = time.perf_counter()
loss = train(base_net, loss_fn, optimizer, device, train_loader)
tok = time.perf_counter()
print(loss)
print(tok - tik)

In [None]:
# Evaluate the backbone on the test split with the project's `test` helper.
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_net.to(device)
acc = test(base_net, device, test_loader)
print(acc)

In [None]:
# Peek at one training batch to sanity-check shapes before meta-training.
dataiter = iter(train_loader)
# DataLoader iterators implement ``__next__``; the ``.next()`` alias is not
# available in recent PyTorch, so use the builtin ``next``.
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)
print(images.shape)

# for name, module in base_net.named_modules():
#     print(f"{name}: {activations[name].shape}")
def get_input_layers(images, layer_names=None, net=None, activation_store=None):
    """Run ``images`` through the backbone and stack its hooked activations.

    Each hooked layer becomes one token of the meta-transformer input: its
    activation is flattened per sample and right-padded with zeros to the
    width of the widest layer.

    Args:
        images: input batch for ``net``; dim 0 is the batch dimension.
        layer_names: keys of ``activation_store`` to stack, in token order.
            Defaults to ``module_list``.
        net: model whose forward hooks fill ``activation_store``.
            Defaults to ``base_net``.
        activation_store: dict the forward hooks write into.
            Defaults to ``activations``.

    Returns:
        Tensor of shape ``(batch, len(layer_names), max_features)`` on
        ``images.device``.

    Raises:
        ValueError: if ``layer_names`` is empty.
        KeyError: if a requested layer produced no activation.
    """
    if net is None:
        net = base_net
    if layer_names is None:
        layer_names = module_list
    if activation_store is None:
        activation_store = activations
    if len(layer_names) == 0:
        raise ValueError("get_input_layers: no layer names to stack")

    # The backbone is only a feature extractor here, so skip building its
    # autograd graph.
    # NOTE(review): if `net` is in train mode, BatchNorm running stats still
    # update during this pass — confirm whether net.eval() is wanted.
    with torch.no_grad():
        net(images)

    # Flatten everything except the batch dim so conv maps (B, C, H, W) and
    # linear outputs (B, F) can share one padded token axis.
    layers = [torch.flatten(activation_store[name], start_dim=1) for name in layer_names]
    batch_size = images.shape[0]
    num_feats = max(layer.shape[1] for layer in layers)

    trans_input = torch.zeros(
        (batch_size, len(layers), num_feats),
        dtype=layers[0].dtype,
        device=images.device,
    )
    for i, layer in enumerate(layers):
        trans_input[:, i, :layer.shape[1]] = layer

    return trans_input

# Build the meta-transformer input for the sample batch and report its shape.
trans_input = get_input_layers(images)
print(trans_input.size())

In [None]:
class MetaTrans(nn.Module):
    """Transformer classifier over a sequence of activation tokens.

    The same sequence is fed as both encoder source and decoder target of an
    ``nn.Transformer``. The output at position 0 is projected to class logits.

    Args:
        d_model: token width (last dim of the input); must be divisible by
            ``nhead``.
        num_layers: number of encoder layers and of decoder layers.
        nhead: number of attention heads. Defaults to 8, ``nn.Transformer``'s
            own default.
        num_classes: number of output logits. Defaults to 10.
    """

    def __init__(self, d_model, num_layers, nhead=8, num_classes=10):
        super().__init__()
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            activation="gelu",
            batch_first=True)
        # Classification head; the attribute name is kept so existing
        # state_dicts still load.
        self.decoder = nn.Linear(d_model, num_classes)

    def forward(self, src):
        """Map ``src`` (batch, seq, d_model) to logits (batch, num_classes)."""
        output = self.transformer(src, src)
        # The head is position-wise, so projecting only token 0 yields the
        # same logits as projecting every token and then selecting token 0.
        return self.decoder(output[:, 0, :])

In [None]:
# Meta-transformer trained on the backbone's hooked activations.
# NOTE(review): d_model is the last dim of the sample meta-input; it must be
# divisible by MetaTrans's nhead (8 unless overridden) — confirm it is.
trans_net = MetaTrans(d_model=trans_input.shape[-1], num_layers=2)
trans_net.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(trans_net.parameters(), lr=1e-5)

META_EPOCHS = 2
LOG_EVERY = 2000  # batches between running-loss reports

for epoch in range(META_EPOCHS):
    # Re-enable dropout in case an evaluation cell left the model in eval mode.
    trans_net.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        trans_input = get_input_layers(images)
        outputs = trans_net(trans_input)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            # Mean loss per batch over the last LOG_EVERY batches.
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / LOG_EVERY:.3f}')
            running_loss = 0.0

In [None]:
# Evaluate the meta-transformer on the test split.
trans_net.eval()  # disable the dropout inside nn.Transformer
correct = 0
total = 0
# No gradients are needed for evaluation.
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        trans_input = get_input_layers(images)
        outputs = trans_net(trans_input)
        # Predicted class = index of the largest logit (first one on ties).
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

In [None]:
# Integer percentage of correctly classified test samples; report the
# actual number of evaluated images rather than a hard-coded count.
print(f'Accuracy of the network on the {total} test images: {100 * correct // total} %')

In [None]:
 # Reference only, not executed: SimpleViT is not imported in this notebook.
 # vit_net = SimpleViT(
 #     image_size = 32,
 #     patch_size = 8,
 #     num_classes = len(classes),
 #     dim = 1024,
 #     depth = 6,
 #     heads = 16,
 #     mlp_dim = 2048
 # )