In [1]:
# default_exp models

## Imports

In [2]:
# EXPORT
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from surrogates4sims.utils import printNumModelParams

## Models

### Deep Fluids Models

In [3]:
# EXPORT 
class convBlock(nn.Module):
    """Encoder stage: ``num_conv`` x [3x3 conv -> norm -> act], optional stride-2 downsampling,
    then channel-wise concatenation of the (resized) block input.

    The output has ``filters + in_channels`` channels. With ``downSample=True`` the spatial size
    is halved: a stride-2 conv is applied to the features, and ``F.interpolate(scale_factor=0.5)``
    to the input copy. With ``downSample=False`` the spatial size is unchanged.

    Args:
        num_conv: number of conv/norm/act triples.
        in_channels: channels of the input tensor.
        filters: channels produced by every convolution in the block.
        act: activation module instance, reused after every convolution.
        downSample: whether to halve the spatial resolution.
        norm: normalization layer class, called as ``norm(filters)``.
    """
    def __init__(self,num_conv=4,in_channels=128,filters=128,act=nn.LeakyReLU(),downSample=True,
                 norm=nn.BatchNorm2d):
        super(convBlock,self).__init__()
        self.num_conv = num_conv
        self.in_channels = in_channels
        self.act = act
        self.downSample = downSample

        layers = []
        ch = in_channels
        for _ in range(num_conv):
            layers += [nn.Conv2d(ch, filters, kernel_size=3, stride=1, padding=1), norm(filters), act]
            ch = filters
        self.convs = nn.Sequential(*layers)
        if downSample:
            # Only built when used, so downSample=True keeps the same parameters/state_dict keys.
            self.downSampleLayer = nn.Conv2d(filters, filters, kernel_size=3, stride=2, padding=1)

    def forward(self,x):
        skip = x
        x = self.convs(x)
        # getattr default: objects pickled before `downSample` was stored always downsampled.
        if getattr(self, 'downSample', True):
            x = self.downSampleLayer(x)
            skip = F.interpolate(skip, scale_factor=.5)
        return torch.cat([x, skip], axis=1)
    
class convTransBlock(nn.Module):
    """Decoder stage that doubles the spatial size.

    Layers: (num_conv-1) x [stride-1 3x3 ConvTranspose -> norm -> act], then a stride-2
    ConvTranspose (output_padding=1) -> act.

    Args:
        num_conv: total number of transposed convolutions (>= 1).
        in_channels: channels of the input tensor.
        filters: channels produced by every transposed convolution.
        act: activation module instance, reused after every layer.
        skip_connection: add a nearest-upsampled copy of the input to the output
            (requires ``in_channels == filters``).
        stack: concatenate a nearest-upsampled copy of the input onto the output,
            giving ``filters + in_channels`` output channels.
        norm: normalization layer class, called as ``norm(filters)``.
    """
    def __init__(self, num_conv=4, in_channels=128, filters=128, act=nn.LeakyReLU(),
                 skip_connection=False, stack=False, norm=nn.BatchNorm2d):
        super(convTransBlock,self).__init__()
        self.filters = filters
        self.num_conv = num_conv
        self.in_channels = in_channels
        self.act = act
        self.skip_connection = skip_connection
        self.stack = stack
        self.upsample = nn.Upsample(scale_factor=2)

        layers = []
        ch = in_channels
        for _ in range(num_conv - 1):
            layers += [nn.ConvTranspose2d(ch, out_channels=filters, kernel_size=3, stride=1, padding=1),
                       norm(filters), act]
            ch = filters
        # With num_conv == 1 the loop above is empty, so this layer must read in_channels.
        layers += [nn.ConvTranspose2d(ch, out_channels=filters, kernel_size=3, stride=2,
                                      padding=1, output_padding=1),
                   act]
        self.seq = nn.Sequential(*layers)

    def forward(self,x):
        out = self.seq(x)
        if self.skip_connection:
            # Out-of-place: activations such as Sigmoid save their output for backward.
            out = out + self.upsample(x)
        if self.stack:
            out = torch.cat([out, self.upsample(x)], axis=1)
        return out
    
class Generator(nn.Module):
    """Decoder from a latent vector to an image-shaped tensor.

    A Linear layer maps the latent vector to a ``filters x h0 x w0`` map, where
    ``(h0, w0) = (H, W) // 2**(repeat_num-1)``. Then ``repeat_num-1`` convTransBlocks each double
    the spatial size. A final 3x3 conv maps to ``output_shape[0]`` channels, optionally followed
    by a Sigmoid.

    Args:
        z: example latent batch; only ``z.shape[1]`` (latent size) is read.
        filters: channels produced by each convTransBlock.
        output_shape: (C, H, W) of one output sample; a tensor, tuple, list or torch.Size.
        num_conv: transposed convolutions per convTransBlock.
        conv_k, last_k: accepted for API compatibility; unused (all kernels are 3x3).
        repeat: number of resolution levels; 0 derives it as ``int(log2(max(H, W))) - 2``.
        skip_connection, stack, norm: forwarded to every convTransBlock. With ``stack`` each
            block adds ``filters`` channels.
        act: activation module instance shared by the convTransBlocks.
        sigmoid_out: append a Sigmoid after the last conv.
    """
    def __init__(self,z, filters, output_shape,
                 num_conv=4, conv_k=3, last_k=3, repeat=0,
                 skip_connection=False, act=nn.LeakyReLU(),stack=False, norm=nn.BatchNorm2d, sigmoid_out=False):
        super(Generator,self).__init__()
        # Plain ints so tensors, tuples, lists and torch.Size are all accepted.
        spatial = [int(s) for s in output_shape[1:]]
        if repeat == 0:
            repeat_num = int(np.log2(max(spatial))) - 2
        else:
            repeat_num = repeat
        x0_shape = [filters] + [int(s // np.power(2, repeat_num-1)) for s in spatial]
        self.x0_shape = x0_shape
        self.output_shape = output_shape
        self.filters = filters
        self.num_conv = num_conv
        self.num_output = int(np.prod(x0_shape))
        self.linear = nn.Linear(z.shape[1], self.num_output)

        blocks = []
        ch = filters
        for _ in range(repeat_num - 1):
            blocks.append(convTransBlock(num_conv, ch, filters, act, skip_connection, stack,
                                         norm=norm))
            if stack:
                ch += filters  # the block concatenated its upsampled input
        self.convTransBlockLayers = nn.Sequential(*blocks)

        # ch never drops below filters, so it is the channel count entering the last conv.
        out_conv = nn.Conv2d(ch, int(output_shape[0]), kernel_size=3, stride=1, padding=1)
        self.lastConv = nn.Sequential(out_conv, nn.Sigmoid()) if sigmoid_out else out_conv
        self.skip_connection = skip_connection
        self.stack = stack

    def forward(self, x):
        x = self.linear(x)
        x = x.view(-1, *self.x0_shape)
        x = self.convTransBlockLayers(x)
        return self.lastConv(x)
    
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to a ``z_num``-dimensional latent vector.

    Pipeline: stem conv -> act, then ``repeat_num`` convBlocks, then a Linear layer on the
    flattened features. Each convBlock halves H and W and adds ``filters`` channels via its
    input skip.

    Args:
        x: example batch shaped (N, C, H, W); only its shape is read. H and W are assumed
            divisible by ``2**repeat_num``.
        filters: channels per convolution.
        z_num: latent size.
        num_conv: convolutions per convBlock.
        conv_k: stem-conv kernel size. Odd sizes keep H and W unchanged (padding ``conv_k // 2``).
        repeat: number of convBlocks; 0 derives it as ``int(log2(max(H, W))) - 2``.
        act: activation module instance used after the stem conv and inside every convBlock.
        norm: normalization layer class passed to the convBlocks.
    """
    def __init__(self, x, filters, z_num, num_conv=4, conv_k=3, repeat=0, act=nn.LeakyReLU(), norm=nn.BatchNorm2d):
        super(Encoder,self).__init__()

        x_shape = x.shape[1:]
        if repeat == 0:
            repeat_num = int(np.log2(np.max(x_shape[1:]))) - 2
        else:
            repeat_num = repeat

        self.filters = filters
        self.act = act
        self.conv1 = nn.Conv2d(x_shape[0], filters, kernel_size=conv_k, stride=1, padding=conv_k // 2)

        ch = filters
        blocks = []
        for _ in range(repeat_num):
            # Honour the caller's activation (it was previously hard-coded to nn.LeakyReLU()).
            blocks.append(convBlock(num_conv, ch, filters, act=act, downSample=True, norm=norm))
            ch += filters

        self.convs = nn.Sequential(*blocks)
        h = [int(n) // 2**repeat_num for n in x_shape[1:]]
        self.nLinearInput = int(ch * np.prod(h))
        self.linear = nn.Linear(self.nLinearInput,z_num)

    def forward(self,x):
        x = self.act(self.conv1(x))
        x = self.convs(x)
        x = x.view(x.size(0),-1)
        return self.linear(x)

class Encoder_LK(nn.Module):
    """Large-kernel encoder: two 13x13, stride-4 convolutions followed by a Linear head.

    Pipeline: BatchNorm on the input -> conv1 (C -> filters) -> act -> BatchNorm ->
    conv2 (filters -> 4*filters) -> act -> BatchNorm -> flatten -> Linear(z_num).
    Each conv maps a spatial size n to ``(n - 1) // 4 + 1``, so a 128x96 input gives an 8x6 map.

    Args:
        x: example batch shaped (N, C, H, W); only its shape is read.
        filters: channels of the first conv; the second produces ``4*filters``.
        z_num: latent size.
        repeat: accepted for API compatibility; unused.
        act: activation module instance applied after each conv.
    """
    def __init__(self, x, filters, z_num, repeat=0, act=nn.LeakyReLU()):
        super(Encoder_LK,self).__init__()

        x_shape = x.shape[1:]
        self.filters = filters
        self.act = act

        self.bn0 = nn.BatchNorm2d(x_shape[0])
        self.conv1 = nn.Conv2d(x_shape[0], filters, kernel_size=13,stride=4,padding=6)
        self.bn1 = nn.BatchNorm2d(filters)
        self.conv2 = nn.Conv2d(filters, 4*filters, kernel_size=13,stride=4,padding=6)
        self.bn2 = nn.BatchNorm2d(4*filters)
        self.z_num = z_num
        # Spatial size after conv1 and conv2: n -> (n + 2*6 - 13) // 4 + 1, applied twice.
        # Computed from x instead of hard-coding 8x6, which only matched 128x96 inputs.
        out_hw = [int(n) for n in x_shape[1:]]
        for _ in range(2):
            out_hw = [(n + 2 * 6 - 13) // 4 + 1 for n in out_hw]
        self.nLinearInput = 4 * filters * out_hw[0] * out_hw[1]
        self.linear = nn.Linear(self.nLinearInput,z_num)

    def forward(self,x):
        x = self.bn0(x)
        x = self.act(self.conv1(x))
        x = self.bn1(x)
        x = self.act(self.conv2(x))
        x = self.bn2(x)
        x = x.view(x.size(0),-1)
        return self.linear(x)

class Decoder_LK(nn.Module):
    """Large-kernel decoder mirroring Encoder_LK.

    Pipeline:
      1. Linear(z_num -> filters*h0*w0) -> act, reshaped to (filters, h0, w0) -> BatchNorm.
      2. 25x25 stride-4 transposed conv -> act -> BatchNorm.
      3. 25x25 stride-4 transposed conv (filters -> 4*filters) -> act -> BatchNorm.
      4. 3x3 conv to C channels.

    Each transposed conv maps n -> 4n, so the seed map is ``(H // 16, W // 16)`` (8x6 for
    128x96 inputs) and the output is ``(N, C, 16*h0, 16*w0)``.

    Args:
        x: example batch shaped (N, C, H, W); only its shape is read. H and W should be
            multiples of 16 for the output to match them exactly.
        filters: channels of the seed map and the first transposed conv.
        z_num: latent size.
        repeat: accepted for API compatibility; unused.
        act: activation module instance applied after the Linear layer and each transposed conv.

    Raises:
        ValueError: if H or W is smaller than 16.
    """
    def __init__(self, x, filters, z_num, repeat=0, act=nn.LeakyReLU()):
        super(Decoder_LK,self).__init__()

        x_shape = x.shape[1:]
        self.filters = filters
        self.act = act
        self.bn0 = nn.BatchNorm2d(filters)
        self.conv1 = nn.ConvTranspose2d(filters,filters,kernel_size=25,stride=4,padding=11,output_padding=1)
        self.conv2 = nn.ConvTranspose2d(filters,4*filters,kernel_size=25,stride=4,padding=11,output_padding=1)
        self.bn1 = nn.BatchNorm2d(4*filters)
        self.conv3 = nn.Conv2d(4*filters,x_shape[0], kernel_size=3,stride=1,padding=1)
        self.z_num = z_num
        # Derived from x instead of the former hard-coded 8x6 (which always produced 128x96).
        self.latent_hw = (int(x_shape[1]) // 16, int(x_shape[2]) // 16)
        if min(self.latent_hw) < 1:
            raise ValueError(f"Decoder_LK needs H, W >= 16, got {tuple(x_shape[1:])}")
        self.nLinearOutput = filters * self.latent_hw[0] * self.latent_hw[1]
        self.linear = nn.Linear(z_num,self.nLinearOutput)

    def forward(self,x):
        x = self.act(self.linear(x))
        # getattr default: objects pickled before latent_hw existed used the fixed 8x6 map.
        h0, w0 = getattr(self, 'latent_hw', (8, 6))
        x = x.view(x.size(0), self.filters, h0, w0)
        x = self.bn0(x)
        x = self.act(self.conv1(x))
        # NOTE(review): bn0 is applied a second time here, so one BatchNorm's running stats mix
        # two different feature distributions. Kept as-is to stay checkpoint-compatible.
        x = self.bn0(x)
        x = self.act(self.conv2(x))
        x = self.bn1(x)
        return self.conv3(x)
    
class AE_LK(nn.Module):
    """Autoencoder pairing an encoder with a decoder (the decoder is stored as ``generator``).

    ``forward(x)`` returns ``(reconstruction, z)`` when ``return_z`` is True, otherwise just the
    reconstruction.
    """
    def __init__(self, encoder_LK, decoder_LK, return_z=True):
        super(AE_LK,self).__init__()
        self.encoder = encoder_LK
        self.generator = decoder_LK
        self.return_z = return_z

    def forward(self,x):
        latent = self.encoder(x)
        recon = self.generator(latent)
        return (recon, latent) if self.return_z else recon
        
class AE_no_P(nn.Module):
    """Plain autoencoder: ``forward(x)`` returns ``generator(encoder(x))``.

    Nothing is injected into the latent code, and the latent code is not returned.
    """
    def __init__(self, encoder,generator):
        super(AE_no_P,self).__init__()
        self.encoder = encoder
        self.generator = generator

    def forward(self,x):
        return self.generator(self.encoder(x))
    
class AE_xhat_z(nn.Module):
    """Autoencoder whose ``forward(x)`` returns ``(generator(z), z)`` with ``z = encoder(x)``."""
    def __init__(self, encoder,generator):
        super(AE_xhat_z,self).__init__()
        self.encoder = encoder
        self.generator = generator

    def forward(self,x):
        latent = self.encoder(x)
        return self.generator(latent), latent
    
class AE_xhat_zV2(nn.Module):
    """Encoder/Generator autoencoder that injects known parameters into the latent vector.

    ``forward(x, p_x)`` encodes ``x``, overwrites the last ``p_x.size(1)`` latent entries with
    ``p_x`` and decodes. It returns ``(x_hat, z)`` when ``return_z`` is True, else ``x_hat``.

    Args:
        X: example batch (N, C, H, W) used to size both sub-networks.
        filters, num_conv, norm: passed to both the Encoder and the Generator.
        latentDim: Encoder output size.
        repeat: passed to the Generator only; the Encoder derives its own depth from X.
        skip_connection, stack, conv_k, last_k, sigmoid_out: passed to the Generator.
        act: Generator activation. Either an nn.Module instance or an nn.Module class;
            a class is instantiated with no arguments.
        return_z: see above.
        stream: if True the output shape drops the first channel of X (C-1 output channels).
        device: device both sub-networks are moved to.
    """
    def __init__(self, X, filters=32, latentDim=16, num_conv=2, repeat=0,
                 skip_connection=False, stack=False, conv_k=3, last_k=3, act=nn.LeakyReLU,
                 return_z=True, stream=True, device='cpu', norm = nn.BatchNorm2d, sigmoid_out=False):
        super(AE_xhat_zV2,self).__init__()

        # The default is the *class* nn.LeakyReLU. nn.Sequential (inside convTransBlock) only
        # accepts module instances, so instantiate classes here.
        if isinstance(act, type):
            act = act()

        self.filters = filters
        self.latentDim = latentDim
        self.num_conv = num_conv
        self.repeat = repeat
        self.skip_connection = skip_connection
        self.stack = stack
        self.act = act
        self.norm = norm
        self.conv_k = conv_k
        self.last_k = last_k
        self.device = device
        self.return_z = return_z
        self.stream = stream

        self.encoder = Encoder(X,filters,latentDim,num_conv=num_conv,norm=norm).to(device)

        # Dry run only to size the Generator's input. eval() + no_grad() keep the example batch
        # from updating BatchNorm running statistics or building an autograd graph.
        self.encoder.eval()
        with torch.no_grad():
            z = self.encoder(X.to(device))
        self.encoder.train()

        if stream:
            self.output_shape = torch.tensor(X[0][1:].shape)
        else:
            self.output_shape = torch.tensor(X[0].shape)

        self.generator = Generator(z, filters, self.output_shape,
                                   num_conv, conv_k, last_k, repeat, skip_connection, act=act,
                                   stack=stack, norm=norm, sigmoid_out=sigmoid_out).to(device)

    def forward(self, x, p_x):
        z = self.encoder(x)
        n_params = p_x.size(1)
        # Guard: with n_params == 0, z[:, -0:] would select (and overwrite) the whole latent.
        if n_params > 0:
            z[:, -n_params:] = p_x
        x = self.generator(z)
        if self.return_z:
            return x, z
        return x
        
class ConvDeconv(nn.Module):
    """Fully convolutional autoencoder: (repeat+1) stride-2 convs down, mirrored back up.

    The latent ``z`` is the encoder's feature map, with ``filters*(repeat+1)`` channels. For H and W
    divisible by ``2**(repeat+1)`` its spatial size is ``(H, W) / 2**(repeat+1)``, and the generator
    restores ``(H, W)``.

    Args:
        X: example batch (N, C, H, W); only its shape is read.
        filters: channel increment per level.
        latentDim, num_conv, skip_connection, stack, last_k: stored only.
        repeat: number of extra down/up levels beyond the first.
        conv_k: kernel size of every (transposed) conv. Padding is fixed at 1, so only 3 restores
            the input resolution.
        act: activation module instance inserted between levels.
        return_z: forward returns ``(x_hat, z)`` when True.
        stream: if True the output has ``C-1`` channels.
        device: device the encoder and generator are moved to.
    """
    def __init__(self, X, filters=32, latentDim=16, num_conv=2, repeat=1,
                 skip_connection=False, stack=False, conv_k=3, last_k=3,
                 act=nn.LeakyReLU(), return_z=True, stream=False, device='cpu'):
        super(ConvDeconv,self).__init__()

        self.filters = filters
        self.latentDim = latentDim
        self.num_conv = num_conv
        self.repeat = repeat
        self.skip_connection = skip_connection
        self.stack = stack
        self.act = act
        self.conv_k = conv_k
        self.last_k = last_k
        self.device = device
        self.return_z = return_z
        self.stream = stream

        n_in = X.shape[1]
        # Encoder: every level halves H, W and adds `filters` channels.
        down = [nn.BatchNorm2d(n_in), nn.Conv2d(n_in, filters, kernel_size=conv_k, stride=2, padding=1)]
        ch = filters
        for _ in range(repeat):
            down += [self.act, nn.BatchNorm2d(ch),
                     nn.Conv2d(ch, ch + filters, kernel_size=conv_k, stride=2, padding=1)]
            ch += filters
        self.encoder = nn.Sequential(*down).to(device)
        # No dry run through the encoder: its result was never used, and in training mode it
        # would update the BatchNorm running statistics with the example batch.

        if stream:
            self.output_shape = torch.tensor(X[0][1:].shape)
        else:
            self.output_shape = torch.tensor(X[0].shape)

        # Generator: mirror image, removing `filters` channels per level.
        up = []
        for _ in range(repeat):
            up += [nn.BatchNorm2d(ch),
                   nn.ConvTranspose2d(ch, ch - filters, kernel_size=conv_k, stride=2,
                                      padding=1, output_padding=1),
                   self.act]
            ch -= filters
        up += [nn.BatchNorm2d(ch),
               nn.ConvTranspose2d(ch, int(self.output_shape[0]), kernel_size=conv_k, stride=2,
                                  padding=1, output_padding=1)]
        self.generator = nn.Sequential(*up).to(device)

    def forward(self,x):
        z = self.encoder(x)
        x = self.generator(z)
        if self.return_z:
            return x, z
        return x
        
class Reshape(nn.Module):
    """Reshape every sample to ``shape`` while keeping the batch dimension: (N, ...) -> (N, *shape).

    Uses ``Tensor.view``, so the input must be view-compatible (e.g. contiguous).
    """
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, *self.shape)
    
class ConvDeconvFactor2(nn.Module):
    """Convolutional autoencoder with a dense bottleneck; channels double at every level.

    Encoder: norm -> stride-2 conv (C -> filters), then ``repeat`` x [act -> norm -> stride-2
    conv doubling the channels], flattened and mapped by a Linear layer to ``latentDim``.

    Generator: Linear back to the bottleneck map, then ``repeat`` x [norm -> stride-2 transposed
    conv halving the channels -> act], then norm -> transposed conv to the output channels,
    optionally followed by a Sigmoid.

    ``forward(x, p_x)`` overwrites the last ``p_x.size(1)`` latent entries with ``p_x``.

    Args:
        X: example batch (N, C, H, W); only its shape is read. H and W are assumed divisible by
            ``2**(repeat+1)``; they no longer need to be equal.
        filters: channels of the first conv.
        latentDim: size of the dense latent vector.
        num_conv, skip_connection, stack, last_k: stored only.
        repeat: number of extra down/up levels beyond the first.
        conv_k: kernel size of every (transposed) conv; padding is fixed at 1.
        act: activation module instance reused between levels.
        return_z: forward returns ``(x_hat, z)`` when True.
        stream: if True the output has ``C-1`` channels.
        device: device the encoder and generator are moved to.
        use_sigmoid_output_layer: append a Sigmoid to the generator.
        norm: normalization layer class, called with a channel count.
    """
    def __init__(self, X, filters=32, latentDim=16, num_conv=2, repeat=1,
                 skip_connection=False, stack=False, conv_k=3, last_k=3,
                 act=nn.LeakyReLU(), return_z=True, stream=False, device='cpu',
                 use_sigmoid_output_layer=False, norm = nn.BatchNorm2d):
        super(ConvDeconvFactor2,self).__init__()

        self.filters = filters
        self.latentDim = latentDim
        self.num_conv = num_conv
        self.repeat = repeat
        self.skip_connection = skip_connection
        self.stack = stack
        self.act = act
        self.conv_k = conv_k
        self.last_k = last_k
        self.device = device
        self.return_z = return_z
        self.stream = stream
        self.use_sigmoid_output_layer = use_sigmoid_output_layer

        x_shape = X.shape
        convLayers = [norm(x_shape[1]), nn.Conv2d(x_shape[1], filters, kernel_size=conv_k, stride=2, padding=1)]
        ch = filters
        for _ in range(repeat):
            convLayers += [self.act, norm(ch), nn.Conv2d(ch, 2*ch, kernel_size=conv_k, stride=2, padding=1)]
            ch = 2*ch

        # Bottleneck feature map is ch x sz x sz_w (sz == sz_w for square inputs).
        factor = 2 ** (repeat + 1)
        self.ch = ch
        self.sz = int(x_shape[2] // factor)
        self.sz_w = int(x_shape[3] // factor)
        n_flat = self.ch * self.sz * self.sz_w
        convLayers += [Reshape(-1), nn.Linear(n_flat, self.latentDim)]
        self.encoder = nn.Sequential(*convLayers).to(device)
        # No dry run through the encoder: its result was never used, and in training mode it
        # would update normalization running statistics with the example batch.

        if stream:
            self.output_shape = torch.tensor(X[0][1:].shape)
        else:
            self.output_shape = torch.tensor(X[0].shape)

        deconvLayers = [nn.Linear(self.latentDim, n_flat),
                        Reshape(self.ch, self.sz, self.sz_w)]
        for _ in range(repeat):
            deconvLayers += [norm(ch),
                             nn.ConvTranspose2d(ch, ch//2, kernel_size=conv_k, stride=2,
                                                padding=1, output_padding=1),
                             self.act]
            ch = ch//2
        deconvLayers += [norm(ch),
                         nn.ConvTranspose2d(ch, int(self.output_shape[0]), kernel_size=conv_k, stride=2,
                                            padding=1, output_padding=1)]
        if use_sigmoid_output_layer:
            deconvLayers.append(nn.Sigmoid())

        self.generator = nn.Sequential(*deconvLayers).to(device)

    def forward(self, x, p_x):
        z = self.encoder(x)
        n_params = p_x.size(1)
        # Guard: with n_params == 0, z[:, -0:] would select (and overwrite) the whole latent.
        if n_params > 0:
            z[:, -n_params:] = p_x
        x = self.generator(z)
        if self.return_z:
            return x, z
        return x
        
        
class MLP(nn.Module):
    """Fully connected network that maps each sample back to its own shape.

    Pipeline: flatten -> hidden Linear/act layers -> Linear back to the flattened input size ->
    reshape to the per-sample input shape.

    Args:
        X: example batch; ``X.shape[1:]`` is the per-sample shape reproduced at the output.
        hiddenLayerSizes: widths of the hidden layers (at least one).
        activation: activation module instance reused after every hidden Linear layer.
    """
    def __init__(self, X, hiddenLayerSizes = (1024,), activation=nn.ELU()):
        super(MLP,self).__init__()

        self.activation = activation
        self.inputSize = X.shape[1:]
        n_in = int(np.prod(self.inputSize))
        sizes = list(hiddenLayerSizes)
        # Build into a local list: assigning `self.modules = [...]` would shadow
        # nn.Module.modules() and make it uncallable.
        layers = [nn.Linear(n_in, sizes[0]), activation]
        for n_a, n_b in zip(sizes[:-1], sizes[1:]):
            layers += [nn.Linear(n_a, n_b), activation]
        layers.append(nn.Linear(sizes[-1], n_in))
        self.layers = nn.Sequential(*layers)

    def forward(self,x):
        x = x.view(x.shape[0],-1)
        x = self.layers(x)
        # Restore any per-sample rank, not only (C, H, W).
        return x.view(x.shape[0], *self.inputSize)

### Playground

In [4]:
# Playground settings. Seed torch so the random tensors and weights below are reproducible.
torch.manual_seed(0)
bz = 8            # batch size
latentDim = 16
filters = 128
num_conv = 4
num_Gconv= 128    # NOTE(review): not referenced by any cell in this notebook

In [5]:
# Single-channel 128x128 example batch.
X = torch.randn(bz, 1, 128, 128)
X.shape

torch.Size([8, 1, 128, 128])

In [6]:
# Two known parameters per sample. The models' forward(x, p_x) writes them into the last latent entries.
p_x = torch.ones(bz, 2)
p_x.shape

torch.Size([8, 2])

In [15]:
# With norm=nn.Identity, no BatchNorm2d layer should appear anywhere in the model.
model = ConvDeconvFactor2(X, filters=128, latentDim=512, num_conv=2, repeat=2, norm=nn.Identity)
assert not any(type(mod) is nn.BatchNorm2d for mod in model.modules())
model

ConvDeconvFactor2(
  (act): LeakyReLU(negative_slope=0.01)
  (encoder): Sequential(
    (0): Identity()
    (1): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (2): LeakyReLU(negative_slope=0.01)
    (3): Identity()
    (4): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.01)
    (6): Identity()
    (7): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): Reshape()
    (9): Linear(in_features=131072, out_features=512, bias=True)
  )
  (generator): Sequential(
    (0): Linear(in_features=512, out_features=131072, bias=True)
    (1): Reshape()
    (2): Identity()
    (3): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.01)
    (5): Identity()
    (6): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (7): LeakyReLU(negative_slope=0.01)
    (8): I

In [None]:
# Default (BatchNorm2d) ConvDeconvFactor2 with three stride-2 levels (first conv + repeat=2).
model = ConvDeconvFactor2(X, filters=128, latentDim=512, num_conv=2, repeat=2)

In [None]:
model

In [None]:
# Bottleneck channels and spatial size feeding the encoder's final Linear layer.
model.ch, model.sz

In [None]:
model.encoder

In [None]:
model.generator

In [None]:
# Reconstruction and latent code; the last p_x.size(1) latent entries are replaced by p_x.
Xhat, z = model(X,p_x)
Xhat.shape, z.shape

(torch.Size([8, 1, 128, 128]), torch.Size([8, 512]))

In [None]:
# Parameter counts: whole model, then encoder and generator separately.
for part in (model, model.encoder, model.generator):
    printNumModelParams(part)

In [None]:
# Parameter count as a function of depth (repeat = 0..5).
for repeat in range(6):
    model = ConvDeconvFactor2(X, filters=128, latentDim=128, num_conv=2, repeat=repeat)
    print(repeat)
    printNumModelParams(model)
    print('-' * 80)

In [None]:
# Three random parameters per sample (used to size the latent in the Generator cells).
p = torch.rand(bz, 3)
p.shape

In [None]:
# One convBlock on X: expected output (bz, filters + 1, 64, 64) -- features plus the resized input.
C = convBlock(num_conv=num_conv,in_channels=X.shape[1],filters=filters)
C

In [None]:
out = C(X)
out.shape

In [None]:
# stack=True: H, W double and the upsampled input is concatenated -> expected (bz, 256, 16, 12).
CTB = convTransBlock(num_conv=num_conv, filters=128, act=nn.LeakyReLU(), skip_connection=False, stack=True)
XX = torch.randn([bz, 128,8,6]) # (bz, 128, 8, 6) feature map
XX.shape

In [None]:
out = CTB(XX)
out.shape

##### Encoder

In [None]:
# Encoder on the 128x128 example: repeat_num = log2(128) - 2 = 5 convBlocks.
E = Encoder(X,filters,latentDim,num_conv=num_conv)
E

In [None]:
printNumModelParams(E)

In [None]:
# Expected latent shape: (bz, latentDim).
out = E(X)
out.shape

In [None]:
# Per-sample output shape (C, H, W) for the Generator cells below.
output_shape = torch.tensor(X.shape[1:])
output_shape

##### Generator

In [None]:
# Random latent batch of size latentDim + number of extra parameters (p has 3 columns).
z = torch.randn(bz, latentDim + p.shape[1])
z.shape

In [None]:
# stack=True: each convTransBlock concatenates its upsampled input, growing the channel count.
G = Generator(z, 128, output_shape,
                 num_conv=2, conv_k=3, last_k=3, repeat=0, 
                 skip_connection=False, act=nn.LeakyReLU(),stack=True)
G

In [None]:
G.output_shape[0]

In [None]:
printNumModelParams(G)

In [None]:
# Expected output: (bz, 1, 128, 128).
out = G(z)
out.shape

##### AE

In [None]:
# EXPORT
class GeneratorOld(nn.Module):
    """Earlier Generator variant, kept for reference.

    A Linear layer maps the latent to a ``filters x h0 x w0`` map, then ``repeat_num-1``
    convTransBlocks double the spatial size.

    NOTE: ``forward`` does not apply ``lastConv`` (the call is commented out), so the output
    has ``filters`` channels rather than ``output_shape[0]``.

    Args:
        z: example latent batch; only ``z.shape[1]`` is read.
        filters: channels of the seed map and of every convTransBlock.
        output_shape: (C, H, W) of one sample; a tensor, tuple, list or torch.Size.
        num_conv: transposed convolutions per convTransBlock.
        conv_k, last_k: accepted for API compatibility; unused.
        repeat: number of resolution levels; 0 derives it as ``int(log2(max(H, W))) - 2``.
        skip_connection: forwarded to the convTransBlocks.
        act: activation module instance for the convTransBlocks.
    """
    def __init__(self,z, filters, output_shape,
                 num_conv=4, conv_k=3, last_k=3, repeat=0,
                 skip_connection=False, act=nn.LeakyReLU()):
        super(GeneratorOld,self).__init__()
        spatial = [int(s) for s in output_shape[1:]]
        if repeat == 0:
            repeat_num = int(np.log2(max(spatial))) - 2
        else:
            repeat_num = repeat
        x0_shape = [filters] + [int(s // np.power(2, repeat_num-1)) for s in spatial]
        self.x0_shape = x0_shape
        num_output = int(np.prod(x0_shape))
        self.linear = nn.Linear(z.shape[1], num_output)
        blocks = []
        for _ in range(repeat_num - 1):
            # Keywords: positionally, `act` used to land on convTransBlock's `filters` argument.
            blocks.append(convTransBlock(num_conv, in_channels=filters, filters=filters,
                                         act=act, skip_connection=skip_connection))
        self.convTransBlockLayers = nn.Sequential(*blocks)
        self.lastConv = nn.Conv2d(filters,int(output_shape[0]),kernel_size=3, stride=1,padding=1)

    def forward(self, x):
        x = self.linear(x)
        x = x.view(-1, *self.x0_shape)
        x = self.convTransBlockLayers(x)
        # lastConv intentionally not applied here (see class docstring).
        return x

In [None]:
# NOTE(review): presumably regenerates the exported module from the `# EXPORT` cells via nbdev;
# uncomment both lines to run it.
# from nbdev.export import notebook2script
# notebook2script()

### The following two cells check that `sigmoid_out` works

#### With `sigmoid_out=True` the output max/min move from about .02/.01 to .5/.49, close to sigmoid(0) = 0.5, as expected when the pre-activation outputs are near zero

In [None]:
repeat = 0
skip_connection = False
stack = False
createStreamFcn = False
model = AE_xhat_zV2(X, filters, latentDim, num_conv, repeat, 
                 skip_connection, stack, conv_k=3, last_k=3, 
                 act=nn.LeakyReLU(), return_z=True, stream=createStreamFcn, device='cpu',
                norm=nn.Identity, sigmoid_out = False)
X_out, z_out = model(X,p_x)
X_out.shape, z_out.shape, X_out.max(), X_out.min()

In [None]:
repeat = 0
skip_connection = False
stack = False
createStreamFcn = False
model = AE_xhat_zV2(X, filters, latentDim, num_conv, repeat, 
                 skip_connection, stack, conv_k=3, last_k=3, 
                 act=nn.LeakyReLU(), return_z=True, stream=createStreamFcn, device='cpu',
                norm=nn.Identity, sigmoid_out = True)
X_out, z_out = model(X,p_x)
X_out.shape, X_out.max(), X_out.min()

In [None]:
repeat = 0
skip_connection = False
stack = False
createStreamFcn = False
model = AE_xhat_zV2(X, filters, latentDim, num_conv, repeat, 
                 skip_connection, stack, conv_k=3, last_k=3, 
                 act=nn.LeakyReLU(), return_z=True, stream=createStreamFcn, device='cpu')

In [None]:
model(X,p_x)[1][:,-2:]

In [None]:
out = E(X)
out.shape

In [None]:
X.shape

In [None]:
Xhat,p = model(X)
Xhat.shape,p.shape

In [None]:
# Loss Function
loss_func = torch.nn.MSELoss()
loss_func

In [None]:
mseVal = loss_func(X,Xhat)
mseVal

### Build AE_no_P

In [None]:
# Latent code from the Encoder built earlier (E is still the Encoder at this point).
z = E(X)
z.shape

In [None]:
# Generator sized from the Encoder's latent (no extra parameters appended).
G = Generator(z, 128, output_shape,
                 num_conv=2, conv_k=3, last_k=3, repeat=0, 
                 skip_connection=False, act=nn.LeakyReLU(),stack=True)
G

In [None]:
# AE_no_P: encoder -> generator, returning only the reconstruction.
model = AE_no_P(E,G)

In [None]:
out = model(X)
out.shape

### Build AE_xhat_z

In [None]:
# Same encoder/generator pair, but forward returns (x_hat, z).
m = AE_xhat_z(E,G)
m

In [None]:
ct = nn.ConvTranspose1d(1,1,kernel_size=13,stride=4,padding=6,output_padding=3)

In [None]:
c = nn.Conv2d(2,2,kernel_size=13,stride=4,padding=6)
out = c(X)
out.shape

In [None]:
out = c(out)
out.shape

In [None]:
out = c(out)
out.shape

In [None]:
x_shape = X.shape[1:]
repeat_num = int(np.log2(np.max(x_shape[1:]))) - 2
repeat_num

In [None]:
E = Encoder_LK(X,16,16)
E

In [None]:
out = E(X)
out.shape

In [None]:
D = Decoder_LK(X,16,16)

In [None]:
xhat = D(out)

In [None]:
xhat.shape

In [None]:
M = AE_LK(E,D)

In [None]:
xhat,z = M(X)

In [None]:
xhat.shape,z.shape 

In [None]:
loss_func(X,xhat)