In [1]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader, Dataset
from unsupervised_meta_learning.proto_utils import CNN_4Layer
from unsupervised_meta_learning.pl_dataloaders import UnlabelledDataModule, UnlabelledDataset

  rank_zero_deprecation(


In [3]:
 # Omniglot training split without labels. With n_support=1 and n_query=3 every
 # item carries n_support + n_query = 4 images (item 'data' shape: [4, 1, 28, 28]).
 dataset_train = UnlabelledDataset(
     'omniglot',
     './data/untarred/',
     split='train',
     transform=None,
     n_images=None,
     n_classes=None,
     n_support=1,
     n_query=3,
 )

In [4]:
dataset_train[0]['data'].shape

torch.Size([4, 1, 28, 28])

In [5]:
# Each batch stacks 5 dataset items; downstream cells treat that axis as the
# "ways" of one task. Shuffling is on, so the drawn batch differs between runs.
dataloader_train = DataLoader(
    dataset_train,
    batch_size=5,
    shuffle=True,
    num_workers=8,
    pin_memory=torch.cuda.is_available(),  # pin host memory only when CUDA is usable
)

In [6]:
# Draw one (shuffled) batch from the loader; it is a mapping with a 'data' entry.
x = next(iter(dataloader_train))

In [7]:
data = x['data'] # [ways x shots x C x H x W] = [5, 4, 1, 28, 28]; the DataLoader batch axis acts as "ways", the task-batch axis is only added by the later unsqueeze(0)

In [8]:
# Add a leading task-batch axis: [ways, shots, C, H, W] -> [1, ways, shots, C, H, W].
# Derived from the raw batch rather than from `data` itself, so re-running this
# cell cannot stack additional leading axes.
data = x['data'].unsqueeze(0)

In [9]:
data.shape

torch.Size([1, 5, 4, 1, 28, 28])

In [10]:
# Leading two axes of `data` are (task batch, ways); here (1, 5).
batch_size, ways = data.shape[:2]

In [11]:
batch_size, ways

(1, 5)

In [12]:
# First shot of every way is the support image; the slice keeps the shots axis (size 1).
x_support = data[:, :, :1]

In [13]:
# Merge the ways and shots axes: [B, ways, shots, C, H, W] -> [B, ways*shots, C, H, W].
# The literal 1 and 3 are the support / query shot counts used when building the dataset.
x_support = x_support.reshape(batch_size, ways * 1, *x_support.shape[-3:])  # here [1, 5, 1, 28, 28]
x_query = data[:, :, 1:]  # the remaining 3 shots of every way
x_query = x_query.reshape(batch_size, ways * 3, *x_query.shape[-3:])  # here [1, 15, 1, 28, 28]

In [14]:
x_query.shape

torch.Size([1, 15, 1, 28, 28])

In [15]:
x_support.shape

torch.Size([1, 5, 1, 28, 28])

In [16]:
# Stack support then query images along the sample axis:
# [1, ways*(n_support+n_query), C, H, W], here [1, 20, 1, 28, 28].
# NOTE: this rebinds `x`, which until now held the raw batch returned by the loader.
x = torch.cat((x_support, x_query), dim=1)

In [39]:
class Encoder(nn.Module):
    """Convolutional encoder mapping images to flat latent vectors.

    Five 3x3 convolutions (three of them with stride 2) are followed by a
    flatten and a linear projection to ``latent_dim``.
    """

    def __init__(self,
                 num_input_channels: int,
                 base_channel_size: int,
                 latent_dim: int,
                 act_fn: object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Channels of the incoming image (1 for grayscale, 3 for RGB)
            - base_channel_size : Width of the first convolutions; the deeper ones use twice this width
            - latent_dim : Size of the latent vector z produced per image
            - act_fn : Activation class, instantiated once after every convolution
        """
        super().__init__()
        c_hid = base_channel_size

        # (in_channels, out_channels, stride) per convolution, in network order.
        # Each stride-2 layer halves the resolution, rounding up:
        # 32 -> 16 -> 8 -> 4, and likewise 28 -> 14 -> 7 -> 4.
        conv_specs = [
            (num_input_channels, c_hid, 2),
            (c_hid, c_hid, 1),
            (c_hid, 2 * c_hid, 2),
            (2 * c_hid, 2 * c_hid, 1),
            (2 * c_hid, 2 * c_hid, 2),
        ]

        layers = []
        for in_ch, out_ch, stride in conv_specs:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=stride))
            layers.append(act_fn())

        # Image grid -> single feature vector. The linear layer's input size
        # (2*c_hid channels * 16 positions) hard-codes a 4x4 final feature map.
        layers.append(nn.Flatten())
        layers.append(nn.Linear(2 * 16 * c_hid, latent_dim))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x``; all axes before the trailing (C, H, W) are collapsed into one batch axis."""
        images = x.view((-1,) + tuple(x.shape[-3:]))
        return self.net(images)

In [65]:
class Decoder(nn.Module):
    """Convolutional decoder mapping latent vectors back to images.

    A linear layer expands the latent code to a 4x4 feature map, which three
    stride-2 transposed convolutions upsample to 32x32 (4 -> 8 -> 16 -> 32).
    The output resolution is therefore always 32x32.
    """

    def __init__(self,
                 num_input_channels: int,
                 base_channel_size: int,
                 latent_dim: int,
                 act_fn: object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Channels of the image to reconstruct (1 for grayscale, 3 for RGB)
            - base_channel_size : Width of the last convolutions; the earlier ones use twice this width
            - latent_dim : Size of the latent vector z consumed per image
            - act_fn : Activation class, instantiated once after every hidden layer
        """
        super().__init__()
        c_hid = base_channel_size
        wide = 2 * c_hid

        def upsample(in_ch, out_ch):
            # Stride-2 transposed conv; with these paddings it exactly doubles H and W.
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3,
                                      output_padding=1, padding=1, stride=2)

        def refine(channels):
            # Resolution-preserving 3x3 convolution.
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        # Latent vector -> flat feature vector for a (2*c_hid, 4, 4) map.
        self.linear = nn.Sequential(nn.Linear(latent_dim, 16 * wide), act_fn())

        self.net = nn.Sequential(
            upsample(wide, wide), act_fn(),                # 4x4 => 8x8
            refine(wide), act_fn(),
            upsample(wide, c_hid), act_fn(),               # 8x8 => 16x16
            refine(c_hid), act_fn(),
            upsample(c_hid, num_input_channels),           # 16x16 => 32x32
            # Bounds the output to [-1, 1]; assumes inputs are scaled to that
            # range -- TODO confirm against the dataset transform.
            nn.Tanh(),
        )

    def forward(self, x):
        """Decode latent vectors of shape [N, latent_dim] into images of shape [N, C, 32, 32]."""
        hidden = self.linear(x)
        feature_map = hidden.reshape(hidden.shape[0], -1, 4, 4)
        return self.net(feature_map)

In [66]:
# Single-channel encoder: 64 base channels, 64-dimensional latent code.
enc = Encoder(num_input_channels=1, base_channel_size=64, latent_dim=64)

In [67]:
# Matching decoder: same channel widths and latent size as the encoder.
dec = Decoder(num_input_channels=1, base_channel_size=64, latent_dim=64)

In [68]:
# Round-trip shape check. The 20 images enter at 28x28 but come back at 32x32,
# because the decoder always upsamples 4 -> 8 -> 16 -> 32.
# NOTE(review): pad/resize the inputs to 32x32 or crop the output before
# computing any reconstruction loss against `x`.
dec(enc(x)).shape

torch.Size([20, 1, 32, 32])