In [9]:
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger
from unsupervised_meta_learning.proto_utils import prototypical_loss, get_prototypes, CAE
from unsupervised_meta_learning.pl_dataloaders import UnlabelledDataModule, UnlabelledDataset
from unsupervised_meta_learning.protoclr import ProtoCLR

In [2]:
dm = UnlabelledDataModule('omniglot', './data/untarred/', split='train', transform=None,
                 n_support=1, n_query=3, n_images=None, n_classes=None, batch_size=50,
                 seed=10)

model = ProtoCLR(model=CAE(1, 64, hidden_size=64), n_support=1, n_query=3, batch_size=50, lr_decay_step=25000, lr_decay_rate=.5, ae=True)

In [5]:
logger = WandbLogger(
    project='ProtoCLR+AE',
    config={
        'batch_size': 50,
        'steps': 100,
        'dataset': "omniglot"
    }
)
trainer = pl.Trainer(
        profiler='simple',
        max_epochs=10000,
        fast_dev_run=False,
        num_sanity_val_steps=2, gpus=1, #logger=logger
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [6]:
ae = CAE(1, 64, hidden_size=64)

In [27]:
dataset_train = UnlabelledDataset(
    dataset='omniglot',
    datapath='./data/untarred/',
    split='train',
    n_support=1,
    n_query=3
)
dataloader_train = DataLoader(dataset_train,
                                      batch_size=50,
                                      shuffle=True,
                                      num_workers=8,
                                      pin_memory=torch.cuda.is_available())

In [28]:
x = next(iter(dataloader_train))

In [29]:
data = x['data'] # [batch_size x ways x shots x image_dim]

In [30]:
data = data.unsqueeze(0)

In [31]:
data.shape

torch.Size([1, 50, 4, 1, 28, 28])

In [32]:
batch_size = data.size(0)
ways = data.size(1)

In [33]:
batch_size, ways

(1, 50)

In [34]:
x_support = data[:,:,:1]
x_support = x_support.reshape((batch_size, ways * 1, *x_support.shape[-3:])) # e.g. [1,50*n_support,*(3,84,84)]
x_query = data[:,:,1:]
x_query = x_query.reshape((batch_size, ways * 3, *x_query.shape[-3:])) # e.g. [1,50*n_query,*(3,84,84)]

# Create dummy query labels
y_query = torch.arange(ways).unsqueeze(0).unsqueeze(2) # batch and shot dim
y_query = y_query.repeat(batch_size, 1, 1)
y_query = y_query.view(batch_size, -1).to('cuda')

y_support = torch.arange(ways).unsqueeze(0).unsqueeze(2) # batch and shot dim
y_support = y_support.repeat(batch_size, 1, 1)
y_support = y_support.view(batch_size, -1).to('cuda')

In [38]:
# Extract features (first dim is batch dim)
x = torch.cat([x_support, x_query], 1) # e.g. [1,50*(n_support+n_query),*(3,84,84)]
z, r = ae.forward(x)
z_support = z[:,:ways * 1] # e.g. [1,50*n_support,*(3,84,84)]
z_query = z[:,ways * 1:] # e.g. [1,50*n_query,*(3,84,84)]

r_supp = r[:,:ways * 1]
r_query = r[:,ways * 1:]

In [39]:
r_supp.shape

torch.Size([1, 50, 1, 28, 28])

In [40]:
r_query.shape

torch.Size([1, 150, 1, 28, 28])

In [49]:
rs1 = r_supp.view(1, 50, 1, 1, 28, 28)[0][0]

In [52]:
rq1 = r_query.view(1, 50, 3, 1, 28, 28)[0][0]

In [88]:
F.mse_loss(rs1, rq1, reduction='none').sum()

  F.mse_loss(rs1, rq1, reduction='none').sum()


tensor(6.7648e-10, grad_fn=<SumBackward0>)

In [81]:
a = (rs1[0]- rq1[0])

In [76]:
b = (rs1[0]- rq1[1])

In [78]:
c = (rs1[0]- rq1[2])

In [85]:
(a**2 + b ** 2+ c**2).sum(dim=[1,2])

tensor([6.7648e-10], grad_fn=<SumBackward1>)