In [40]:
import functools

import numpy as np
import torch
import webdataset as wds
from torch import nn

In [2]:
# Field order of the comma-separated label record stored in each sample's ``cls`` entry.
LABEL_FIELDS = ("id", "sex", "age", "handedness", "index")


def selectLabel(x, lbl):
    """Extract one field from a comma-separated label record.

    Parameters
    ----------
    x : bytes or str
        The raw record whose fields are ordered as in ``LABEL_FIELDS``.
        Bytes are decoded as UTF-8; an already-decoded ``str`` is used as-is.
    lbl : str
        Name of the field to select (case-insensitive), one of ``LABEL_FIELDS``.

    Returns
    -------
    str or float
        The ``id`` field is returned unchanged as a string; every other
        field is converted with ``float``.

    Raises
    ------
    ValueError
        If ``lbl`` is not a known field name, or the selected field cannot
        be parsed as a float.
    """
    key = lbl.lower()
    if key not in LABEL_FIELDS:
        raise ValueError(f"unknown label {lbl!r}; expected one of {LABEL_FIELDS}")
    lbl_idx = LABEL_FIELDS.index(key)
    if isinstance(x, (bytes, bytearray)):
        x = x.decode("utf-8")
    value = x.split(",")[lbl_idx]
    # Only the identifier is non-numeric; all other fields are numeric labels.
    return value if lbl_idx == 0 else float(value)

def add_chan_dim(x):
    """Convert ``x`` to a new tensor and prepend a singleton channel axis.

    ``torch.tensor`` always copies, so the result never shares memory with
    ``x``. An input of shape ``(H, W)`` comes back as ``(1, H, W)``.
    """
    return torch.tensor(x).unsqueeze(0)

In [3]:
# Every split lives under the same prefix; only the split name differs.
CHILDMIND_URL = "https://childmind.s3.us-west-1.amazonaws.com/python/childmind_{split}.tar"


def make_childmind_dataset(split, label="sex"):
    """Build a WebDataset pipeline yielding ``(npy, cls)`` tuples for one split.

    Parameters
    ----------
    split : str
        Split name substituted into ``CHILDMIND_URL`` ("train", "val" or "test").
    label : str
        Field name passed to ``selectLabel`` to pick the target from ``cls``.

    ``functools.partial`` is used instead of a lambda so the pipeline can be
    pickled, e.g. when a ``DataLoader`` ships it to spawned worker processes.
    """
    return (
        wds.WebDataset(CHILDMIND_URL.format(split=split))
        .decode()
        .map_dict(npy=add_chan_dim, cls=functools.partial(selectLabel, lbl=label))
        .to_tuple("npy", "cls")
    )


train_data = make_childmind_dataset("train")
val_data = make_childmind_dataset("val")

In [235]:
class VAE(nn.Module):
    """Convolutional autoencoder for single-channel 2-D inputs.

    NOTE(review): despite the name, ``forward`` performs no sampling and has
    no mean / log-variance heads -- it is a deterministic encode -> decode
    pass (a plain autoencoder), not a variational one.

    With a ``(B, 1, 24, 256)`` input the shapes are:

    * encoder: three 6x6 / stride-2 / pad-2 convs (24x256 -> 12x128 -> 6x64
      -> 3x32), a 3x3 valid conv (-> 1x30), flatten (32*1*30 = 960), then a
      linear layer down to ``latent_dim``.
    * decoder: linear back to 960, reshape to (32, 1, 30), then the mirrored
      ``ConvTranspose2d`` stack back to ``(B, 1, 24, 256)``.

    NOTE(review): every block ends in ReLU, including the latent code and
    the final reconstruction, so both are clamped to be non-negative --
    confirm this suits the input data's value range.

    Parameters
    ----------
    latent_dim : int, default 10
        Size of the bottleneck code.
    """

    # (channels, height, width) of the last encoder conv's output for a
    # 24x256 input. Both the encoder's flatten -> linear size and the
    # decoder's reshape derive from it, so they cannot drift apart.
    BOTTLENECK_SHAPE = (32, 1, 30)

    def __init__(self, latent_dim=10):
        super().__init__()
        self.latent_dim = latent_dim
        channels, height, width = self.BOTTLENECK_SHAPE
        flat_dim = channels * height * width

        # Modules are created in the same order as before, so seeded
        # initialization and state_dict keys are unchanged.
        encoder_l = [self.encoder_conv_block(True)]
        encoder_l += [self.encoder_conv_block() for _ in range(2)]
        encoder_l.append(self.encoder_conv_block(False, channels, channels, 3, 1, 0))
        encoder_l.append(nn.Flatten())
        encoder_l.append(self.encoder_linear_block(flat_dim, latent_dim))
        self.encoder = nn.ModuleList(encoder_l)

        decoder_l = [self.decoder_linear_block(latent_dim, flat_dim)]
        decoder_l.append(self.decoder_conv_block(False, channels, channels, 3, 1, 0))
        decoder_l += [self.decoder_conv_block() for _ in range(2)]
        decoder_l.append(self.decoder_conv_block(True))
        self.decoder = nn.ModuleList(decoder_l)

    def encoder_conv_block(self, is_start=False, in_channels=32, out_channels=32, kernel_size=6, stride=2, padding=2):
        """Return ``Conv2d`` followed by ``ReLU``.

        When ``is_start`` is true the block takes a single input channel and
        ``in_channels`` is ignored.
        """
        conv_in = 1 if is_start else in_channels
        return nn.Sequential(
            nn.Conv2d(conv_in, out_channels, kernel_size, stride, padding),
            nn.ReLU()
        )

    def encoder_linear_block(self, in_chan, out_chan):
        """Return ``Linear(in_chan, out_chan)`` followed by ``ReLU``."""
        return nn.Sequential(
            nn.Linear(in_chan, out_chan),
            nn.ReLU()
        )

    def decoder_conv_block(self, is_last=False, in_channels=32, out_channels=32, kernel_size=6, stride=2, padding=2):
        """Return ``ConvTranspose2d`` followed by ``ReLU``.

        When ``is_last`` is true the block emits a single output channel and
        ``out_channels`` is ignored.
        """
        conv_out = 1 if is_last else out_channels
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, conv_out, kernel_size, stride, padding),
            nn.ReLU()
        )

    def decoder_linear_block(self, in_chan, out_chan):
        """Return ``Linear(in_chan, out_chan)`` followed by ``ReLU``."""
        return nn.Sequential(
            nn.Linear(in_chan, out_chan),
            nn.ReLU()
        )

    def forward(self, x):
        """Encode ``x`` to the latent code and decode it back.

        Assumes ``x`` is ``(B, 1, 24, 256)``, the size ``BOTTLENECK_SHAPE``
        was derived for -- TODO confirm before using other input sizes.
        """
        for layer in self.encoder:
            x = layer(x)
        x = self.decoder[0](x)                   # latent -> flat features
        x = x.view(-1, *self.BOTTLENECK_SHAPE)   # flat -> (B, C, H, W)
        for layer in self.decoder[1:]:
            x = layer(x)
        return x

In [236]:
# Instantiate the model and show its layer-by-layer structure.
vae = VAE()
print(str(vae))

VAE(
  (encoder): ModuleList(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(6, 6), stride=(2, 2), padding=(2, 2))
      (1): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(32, 32, kernel_size=(6, 6), stride=(2, 2), padding=(2, 2))
      (1): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(32, 32, kernel_size=(6, 6), stride=(2, 2), padding=(2, 2))
      (1): ReLU()
    )
    (3): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (4): Flatten(start_dim=1, end_dim=-1)
    (5): Sequential(
      (0): Linear(in_features=960, out_features=10, bias=True)
      (1): ReLU()
    )
  )
  (decoder): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=10, out_features=960, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(32, 32, kernel_size=(6, 6), stride=(2, 2), paddin

In [237]:
# Sanity check: the autoencoder must reproduce its input shape exactly.
# no_grad avoids building an autograd graph for this throwaway forward pass.
with torch.no_grad():
    recon_shape = vae(torch.zeros(1, 1, 24, 256)).shape
assert recon_shape == (1, 1, 24, 256), f"unexpected reconstruction shape {tuple(recon_shape)}"
print(recon_shape)

torch.Size([1, 1, 24, 256])


In [None]:
def out_W(W, F, P, S, dilation=1):
    """Output size of a convolution along one spatial axis.

    Implements the PyTorch ``Conv2d`` output-size formula
    ``floor((W + 2P - dilation*(F - 1) - 1) / S + 1)``; with the default
    ``dilation=1`` it reduces to ``floor((W - F + 2P) / S + 1)``.

    Parameters
    ----------
    W : int or array-like
        Input size along the axis.
    F : int
        Kernel size.
    P : int
        Zero padding added on each side.
    S : int
        Stride.
    dilation : int, default 1
        Spacing between kernel elements.

    Returns
    -------
    numpy float (or array)
        The result of ``np.floor``, so it also works element-wise on arrays.
    """
    return np.floor((W + 2 * P - dilation * (F - 1) - 1) / S + 1)

# Trace the last (width) axis of a 256-wide input through the three
# 6x6 / stride-2 / pad-2 encoder convs, then through the 3x3 valid conv.
W = 256
for _ in range(3):
    W = out_W(W, F=6, P=2, S=2)

print(W)
print(out_W(W, F=3, P=0, S=1))