In [1]:
from fevd_vqvae.models import VQModel
import torch
import torch.nn as nn
from fevd_vqvae.models.model_modules import Encoder
from fevd_vqvae.models.vector_quantizer import VectorQuantizer
from fevd_vqvae.models.utils import instantiate_from_config
from fevd_vqvae.models.loss import VQLoss
from omegaconf import OmegaConf
from fevd_vqvae.utils import setup_dataloader
from fevd_vqvae.models.model_modules import Normalize
cfg_dict = OmegaConf.load("configs/baseline.yaml")
model_cfg_dict = cfg_dict['model']
train_cfg_dict = cfg_dict['setup']
model_cfg_dict
device = torch.device('cuda')
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_cfg_dict

{'key': '2d', 'embed_dim': 192, 'n_embed': 1024, 'ddconfig': {'double_z': False, 'z_channels': 256, 'resolution': 64, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 1, 2, 2, 4], 'num_res_blocks': 2, 'attn_resolutions': [16], 'dropout': 0.0}, 'lossconfig': {'codebook_weight': 1.0, 'pixelloss_weight': 1.0, 'perceptual_weight_2d': 1.0, 'fvd_mu_weight': 0.0, 'fvd_cov_weight': 0.0}}

In [3]:

class VQModel(nn.Module):
    def __init__(self,
                 key,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_sd=None):
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        self.loss = VQLoss(**lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
        self.quant_conv = nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)

        if key == '2d':
            #self.decoder = Decoder2D(**ddconfig)
            self.post_quant_conv = nn.Conv3d(embed_dim, ddconfig["z_channels"], 1)

    def encode(self, x):

        input_is_videos = len(x.shape) == 5
        if input_is_videos:
            B, T, C, H, W = x.shape  # the input is a batch of videos (B, T, C, H, W)
            x = x.reshape(B * T, C, H, W)

        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)

        if input_is_videos:
            _, C, H, W = quant.shape
            quant = quant.reshape(B, T, C, H, W)
        return quant, emb_loss, info

    def forward(self, real_videos):
        quant, codebook_loss, _ = self.encode(real_videos)
        print("Quant Shape: ", quant.shape)
        quant = quant.permute((0, 2, 1, 3, 4))
        return self.post_quant_conv(quant)
        #rec_videos = self.decode(quant)
        #return rec_videos, codebook_loss

In [4]:
model = VQModel(**model_cfg_dict)

In [5]:
train_dataloader = setup_dataloader(root_dir_path="data/dataset/preprocessed_data/", **train_cfg_dict['train_dataloader'])
eval_dataloaders = setup_dataloader(root_dir_path="data/dataset/preprocessed_data/", **train_cfg_dict['eval_dataloader'])

In [6]:
def nonlinearity(x):
    # swish
    return x*torch.sigmoid(x)

class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv3d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x

class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv3d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv3d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv3d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv3d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Compute spatiotemporal attention
        b,c,t, h,w = q.shape
        print("Q Shape: ", q.shape)
        q = q.reshape(b,c, t*h*w)
        q = q.permute(0,2,1)   # b,t*hw,c
        k = k.reshape(b,c,t*h*w) # b,c,t*hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b,c,t*h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,t*hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b,c,t,h,w)

        h_ = self.proj_out(h_)

        return x+h_

class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv3d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv3d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv3d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv3d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h

class Decoder3D(torch.nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv3d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        print("Input Z: ", z.shape)
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)
        print("Conv_In: ", h.shape)

        # middle
        h = self.mid.block_1(h, temb)
        print("Block 1: ", h.shape)
        h = self.mid.attn_1(h)
        print("ATTN : ", h.shape)
        h = self.mid.block_2(h, temb)
        print("Block 2: ", h.shape)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

In [7]:
decoder = Decoder3D(**model_cfg_dict['ddconfig'])

Working with z of shape (1, 256, 4, 4) = 4096 dimensions.


In [8]:
for m, x in enumerate(eval_dataloaders['train']):
    print(x.shape)
    enc = model(x)
    print(enc.shape)
    #print(x.shape) # (N, T, C, H, W)
    #x = x.permute((0, 2, 1, 3, 4)) #(N, C, T, H, W) for 3D conv
    #print(model(x).shape)
    break
decoded = decoder(enc)

torch.Size([10, 12, 3, 64, 64])
Quant Shape:  torch.Size([10, 12, 192, 4, 4])
torch.Size([10, 256, 12, 4, 4])
Input Z:  torch.Size([10, 256, 12, 4, 4])
Conv_In:  torch.Size([10, 512, 12, 4, 4])
Block 1:  torch.Size([10, 512, 12, 4, 4])
Q Shape:  torch.Size([10, 512, 12, 4, 4])
ATTN :  torch.Size([10, 512, 12, 4, 4])
Block 2:  torch.Size([10, 512, 12, 4, 4])
Q Shape:  torch.Size([10, 256, 48, 16, 16])
Q Shape:  torch.Size([10, 256, 48, 16, 16])
Q Shape:  torch.Size([10, 256, 48, 16, 16])


: 

: 