In [2]:
from cosmikyu import gan, config, model
from cosmikyu import nn as cnn
import numpy as np
import os
import torchvision.transforms as transforms
from torchvision import datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import torch
import mlflow
import torchsummary

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [8]:
shape = (3,128,128)
genin_shape = (3,128,128)
discin_shape = (6,128,128)
unet_gen = model.ResUNET_Generator(
    genin_shape,
    nconv_layer=3,
    nconv_fc=64,
    ngpu=1,
    activation=[torch.nn.Tanh()],
    nin_channel = 3,
    nout_channel = 1,
    nthresh_layer = 1,
).to(device="cpu")

torchsummary.summary(unet_gen, genin_shape, device="cpu")


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           1,792
       BatchNorm2d-2         [-1, 64, 128, 128]             128
              ReLU-3         [-1, 64, 128, 128]               0
            Conv2d-4         [-1, 64, 128, 128]          36,928
            Conv2d-5         [-1, 64, 128, 128]           1,792
      ResUNetBlock-6         [-1, 64, 128, 128]               0
       BatchNorm2d-7         [-1, 64, 128, 128]             128
              ReLU-8         [-1, 64, 128, 128]               0
            Conv2d-9          [-1, 128, 64, 64]          73,856
      BatchNorm2d-10          [-1, 128, 64, 64]             256
             ReLU-11          [-1, 128, 64, 64]               0
           Conv2d-12          [-1, 128, 64, 64]         147,584
           Conv2d-13          [-1, 128, 64, 64]          73,856
      BatchNorm2d-14          [-1, 128,

In [3]:
shape = (3,128,128)
genin_shape = (256,)
discin_shape = (6,128,128)
unet_gen = model.ResUNET_DCGAN_Generator(
    shape,
    nconv_layer=3,
    nconv_fc=64,
    ngpu=1,
    activation=[torch.nn.Tanh()],
    nin_channel = 3,
    nout_channel = 3,
    nthresh_layer = 1,
    latent_dim=256,
).to(device="cpu")

torchsummary.summary(unet_gen, genin_shape, device="cpu")


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                [-1, 49152]      12,632,064
           Reshape-2          [-1, 3, 128, 128]               0
       BatchNorm2d-3          [-1, 3, 128, 128]               6
         LeakyReLU-4          [-1, 3, 128, 128]               0
            Conv2d-5         [-1, 64, 128, 128]           1,792
       BatchNorm2d-6         [-1, 64, 128, 128]             128
         LeakyReLU-7         [-1, 64, 128, 128]               0
            Conv2d-8         [-1, 64, 128, 128]          36,928
            Conv2d-9         [-1, 64, 128, 128]           1,792
     ResUNetBlock-10         [-1, 64, 128, 128]               0
      BatchNorm2d-11         [-1, 64, 128, 128]             128
        LeakyReLU-12         [-1, 64, 128, 128]               0
           Conv2d-13          [-1, 128, 64, 64]          73,856
      BatchNorm2d-14          [-1, 128,

In [28]:
shape = (5,128,128)
genin_shape = (256,)
discin_shape = (6,128,128)
unet_gen = model.ResDCGAN_Generator(
    shape,
    latent_dim=256,
    nconv_layer=3,
    nconv_fc=64,
    ngpu=1,
    activation=[torch.nn.Tanh()],
    kernal_size=3, stride=2,
    padding=1, output_padding=0
).to(device="cpu")

torchsummary.summary(unet_gen, genin_shape, device="cpu")

up0
up1
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                [-1, 65536]      16,842,752
           Reshape-2          [-1, 256, 16, 16]               0
       BatchNorm2d-3          [-1, 256, 16, 16]             512
         LeakyReLU-4          [-1, 256, 16, 16]               0
   ConvTranspose2d-5          [-1, 128, 31, 31]         295,040
ResUNetUPInterface-6          [-1, 128, 31, 31]               0
       BatchNorm2d-7          [-1, 128, 31, 31]             256
         LeakyReLU-8          [-1, 128, 31, 31]               0
            Conv2d-9          [-1, 128, 31, 31]         147,584
      BatchNorm2d-10          [-1, 128, 31, 31]             256
        LeakyReLU-11          [-1, 128, 31, 31]               0
           Conv2d-12          [-1, 128, 31, 31]         147,584
           Conv2d-13          [-1, 128, 31, 31]         147,584
      BatchNorm2d-14          [

In [18]:
unet_gen.model_dict["up0_int"]

ResUNetUPInterface(
  (model): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  )
)

In [None]:
init_layers = [nn.Linear(256, self.nin_channel * self.shape[-1] ** 2),
          cnn.Reshape((self.nin_channel, self.shape[-1], self.shape[-1])),
          nn.BatchNorm2d(self.nin_channel), kernal_size=4, stride=2,
            padding=1, output_padding=0
          nn.LeakyReLU(0.2, inplace=True)]


In [6]:
shape = (3,128,128)
genin_shape = (3,224,224)
discin_shape = (6,128,128)
unet_gen = model.ResVAE_Generator(
    genin_shape,
    nconv_layer=3,
    nconv_fc=64,
    ngpu=1,
    activation=[torch.nn.Tanh()],
    nin_channel = 3,
    nout_channel = 1,
    nthresh_layer = 1,
).to(device="cpu")

torchsummary.summary(unet_gen, genin_shape, device="cpu")


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
            Conv2d-4         [-1, 64, 224, 224]          36,928
            Conv2d-5         [-1, 64, 224, 224]           1,792
      ResUNetBlock-6         [-1, 64, 224, 224]               0
       BatchNorm2d-7         [-1, 64, 224, 224]             128
              ReLU-8         [-1, 64, 224, 224]               0
            Conv2d-9        [-1, 128, 112, 112]          73,856
      BatchNorm2d-10        [-1, 128, 112, 112]             256
             ReLU-11        [-1, 128, 112, 112]               0
           Conv2d-12        [-1, 128, 112, 112]         147,584
           Conv2d-13        [-1, 128, 112, 112]          73,856
      BatchNorm2d-14        [-1, 128, 1