In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch import nn

from torchmeta.datasets.helpers import omniglot
from torchmeta.datasets import Omniglot
from torchmeta.utils.data import BatchMetaDataLoader
import matplotlib.pyplot as plt
from torchmeta.transforms import Categorical
from dotted.utils import dot
import hypernet as hn
from tqdm import tqdm
import numpy as np

In [None]:
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
# N-way, K-shot episodes.
N, K = 5, 5

dataset = omniglot(
    "data",
    ways=N,
    shots=K,
    test_shots=15,
    meta_train=True,
    download=True,
)
dataloader = BatchMetaDataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

In [None]:
# Grab a single meta-batch for inspection. Unlike a `for ...: break` loop, this
# fails immediately (StopIteration) on an empty loader instead of leaving `T`
# unbound or silently reusing a stale `T` from an earlier kernel run.
T = next(iter(dataloader))

In [None]:
def plot_task(
    x_train, y_train, x_test, y_test, n
):
    """Plot the "train" and "test" images of a single task.

    One figure is drawn per split, titled with the split name and laid out as
    an ``n`` x ``len(x) // n`` grid. Each panel shows one image (squeezed
    before display) with its label as the panel title. Trailing samples that
    do not fill a complete row are not shown.

    The grid is filled row by row in the order the samples appear, so each
    row shows a single class only if samples are stored class-by-class.
    NOTE(review): assumed to hold for the torchmeta task batches used here;
    confirm.

    Args:
        x_train, x_test: indexable sequences of images; each item is squeezed
            before ``imshow``, e.g. a ``(1, H, W)`` tensor.
        y_train, y_test: labels aligned with ``x_train`` / ``x_test``; tensor
            or numpy scalars are unwrapped with ``.item()`` for the title.
        n: number of grid rows (presumably the number of ways).
    """
    for split_name, x, y in [("train", x_train, y_train), ("test", x_test, y_test)]:
        k = len(x) // n
        # squeeze=False keeps `ax` 2-D even when n == 1 or k == 1, so
        # ax[row, col] indexing is always valid.
        fig, ax = plt.subplots(n, k, figsize=(1.5 * k, 1.5 * n), squeeze=False)
        for row in range(n):
            for col in range(k):
                i = row * k + col
                label = y[i]
                # Tensor / numpy scalars would otherwise print as e.g. "tensor(3)".
                if hasattr(label, "item"):
                    label = label.item()
                ax[row, col].imshow(x[i].squeeze(), cmap="gray")
                ax[row, col].set_title(f"{label}")
                ax[row, col].axis("off")
        fig.suptitle(split_name)

In [None]:
# Visualise the first task of the sampled meta-batch.
(x_train, y_train), (x_test, y_test) = T["train"], T["test"]

plot_task(x_train[0], y_train[0], x_test[0], y_test[0], n=N)

In [None]:
# Hypernetwork sized for the N-way, K-shot setting, moved to the chosen device.
hypernet = hn.HyperNetwork(n=N, k=K, hidden_size=32)
hypernet.to(device)

loss_fn = nn.CrossEntropyLoss()
h_opt = torch.optim.Adam(hypernet.parameters(), lr=4e-4)

# Total number of parameters, shown as the cell output.
sum(param.numel() for param in hypernet.parameters())

In [None]:
n_tasksets = 100      # number of meta-batches drawn from `dataloader`
taskset_epochs = 10   # optimizer steps taken on each meta-batch


def _accuracy(logits, targets):
    """Return the fraction of rows of `logits` whose argmax equals `targets`, as a float."""
    return (logits.argmax(dim=1) == targets).float().mean().item()


progress = tqdm(dataloader, total=n_tasksets)
for task_set_id, tasks in enumerate(progress):
    # The batch tensors do not change across epochs, so move them to the device once.
    X_train, Y_train = (t.to(device) for t in tasks["train"])
    X_test, Y_test = (t.to(device) for t in tasks["test"])

    for e in range(taskset_epochs):
        train_losses, test_losses = [], []
        train_accs, test_accs = [], []
        loss_sum = 0

        # zip iterates over the first dimension, i.e. one task of the meta-batch at a time.
        for x_train, y_train, x_test, y_test in zip(X_train, Y_train, X_test, Y_test):
            # The hypernetwork builds a task-specific model `tn` from the task's train split.
            tn = hypernet(x_train, y_train)
            y_pred_train = tn(x_train)
            y_pred_test = tn(x_test)

            train_loss = loss_fn(y_pred_train, y_train)
            test_loss = loss_fn(y_pred_test, y_test)
            # Objective: train-split + test-split loss, summed over the tasks in the batch.
            loss_sum = loss_sum + (train_loss + test_loss)

            # Log the two losses separately so "tr_l" is the train loss only and
            # "te_l" is actually populated (an empty list would average to NaN).
            train_losses.append(train_loss.item())
            test_losses.append(test_loss.item())
            train_accs.append(_accuracy(y_pred_train, y_train))
            test_accs.append(_accuracy(y_pred_test, y_test))

        # Clear before backward so an interrupted earlier run cannot leave stale grads.
        h_opt.zero_grad()
        loss_sum.backward()
        h_opt.step()

        metrics = {
            "tr_l": float(np.mean(train_losses)),
            "te_l": float(np.mean(test_losses)),
            "tr_a": float(np.mean(train_accs)),
            "te_a": float(np.mean(test_accs)),
        }
        # progress.write prints above the bar instead of corrupting it like print().
        progress.write(f"{task_set_id} {e} {metrics}")

    if task_set_id == n_tasksets - 1:
        break

progress.close()
    
