In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import copy


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
class View(nn.Module):
    def __init__(self,shape):
        super().__init__()
        self.shape =(shape,)
    def forward(self,x):
        return x.view(*self.shape)
class FadeIn(nn.Module):
    def __init__(self,x_low,x_high):
        super().__init__()
        self.x_low = x_low
        self.x_high = x_high
        self.alpha = 0.5

    def update_alpha(self, delta):
        self.alpha = self.alpha + delta
        self.alpha = max(0, min(self.alpha, 1.0))

    def forward(self,x):
        y_low = self.x_low(x)
        y_high = self.x_high(x)

        return (1.0-self.alpha)*y_low+self.alpha*y_high
    


In [12]:
class myUtils(nn.Module):
    def __init__(self):
        super().__init__()
    ## ------ dist
    def conv_twice_for_dist(self,ch_in,ch_out):
        layers = nn.Sequential(
            nn.Conv2d(ch_in,ch_in,3,1,1),
            nn.ReLU(),
            nn.Conv2d(ch_in,ch_out,3,1,1),
            nn.ReLU(),        
            nn.AvgPool2d(kernel_size=2), 
        )
        return layers

    def from_rgb(self,resl):
        if resl <6:
            layers = nn.Sequential(
            nn.Conv2d(3,512,3,1,1),
            nn.ReLU(),
        )
        else:
            layers = nn.Sequential(
            nn.Conv2d(3,2**(14-resl),3,1,1),
            nn.ReLU(),
        )
        return layers

    def last_block(self,):
        layers = nn.Sequential(
            #[2,512,4,4]
            nn.Conv2d(512,512,3,1,1),
            nn.ReLU(),
            nn.Conv2d(512,512,4,1,0),
            nn.ReLU(),
        
            View((-1,512)),
            nn.Linear(512,1),
            nn.Sigmoid()
        )
        return layers    

    ## ------ gen
    def conv_twice_for_gen(self,ch_in,ch_out):
        layers = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in,ch_out,3,1,1),
            nn.ReLU(),
            nn.Conv2d(ch_out,ch_out,3,1,1),
            nn.ReLU(),        
        )
        return layers

    def start_block(self,):
        layers = nn.Sequential(
            #[2,512]
            View((-1,512,1,1)),
            nn.Conv2d(512,512,4,1,3),
            nn.ReLU(),
            nn.Conv2d(512,512,3,1,1),
            nn.ReLU(),
        )
        return layers

    def to_rgb(self,resl):
        if resl <6:
            layers = nn.Sequential(
            nn.Conv2d(512,3,3,1,1),
            nn.ReLU(),
        )
        else:
            layers = nn.Sequential(
            nn.Conv2d(2**(14-resl),3,3,1,1),
            nn.Tanh(),
        )
        return layers

    ## ------ both
    def deepcopy_module(self,module, target):
        new_module = nn.Sequential()
        for name, m in module.named_children():
            if name == target:
                new_module.add_module(name, m)                          # make new structure and,
                new_module[-1].load_state_dict(m.state_dict())         # copy weights
        return new_module

In [8]:
class Dist(myUtils):
    def __init__(self):
        super().__init__()
        self.name = None
        self.model = self.init_dist()
    def init_dist(self,):
        layers = nn.Sequential()
        layers.add_module('from_rgb',self.from_rgb(2)) 
        layers.add_module('last_block',self.last_block()) 
        #add_module을 안 하고 일단 layer만 리스트에 추가해도 나중에 이름 붙일 수 있나?
        return layers

    def intermidate_layer(self,resl): # resl=3부터 시작
        
        self.name = f"intermidate_{int(2**resl)}x{int(2**resl)}to{int(2**(resl-1))}x{int(2**(resl-1))}"

        if (6<=resl):
            layers=self.conv_twice_for_dist(2**(14-resl),2**(15-resl))
        
        elif (2<resl<6):
            layers=self.conv_twice_for_dist(512,512)

        return layers


    def grow_layer(self,resl):
        prev_rgb_block = self.deepcopy_module(self.model,'from_rgb')
        prev_head = nn.Sequential(nn.AvgPool2d(2),prev_rgb_block)

        next_head = nn.Sequential()
        next_head.add_module('next_rgb_block',self.from_rgb(resl))
        next_head.add_module('next_intermidate',self.intermidate_layer(resl))
        
        #fade in 적용
        new_model = nn.Sequential()
        new_model.add_module('fadein',FadeIn(prev_head,next_head))

        #나머지 부분 weight deepcopy
        for name,module in self.model.named_children():
            if name!= 'from_rgb':
                new_model.add_module(name,module)
                new_model[-1].load_state_dict(module.state_dict())
        print(f"resl={resl} new_model",new_model)
        self.model = new_model

    def flush_model(self,resl):
        high_head = self.model.fadein.x_high
        #print(f"resl={resl} high_head",high_head)
        rgb_block = self.deepcopy_module(high_head,'next_rgb_block')
        intermidate_block = self.deepcopy_module(high_head,'next_intermidate')
        print(f"==========intermidate",intermidate_block)
        new_model = nn.Sequential()
        new_model.add_module('from_rgb',rgb_block)   
        new_model.add_module(self.name,intermidate_block)  #이거 이름 뭘로 해야하지?  

        for name,module in self.model.named_children():
            if name!= 'fadein':
                new_model.add_module(name,module)
                new_model[-1].load_state_dict(module.state_dict())    

        self.model = new_model
        print(f"==========resl={resl} after flush",self.model)
        # rgb_block은 rgb_block으로 이름 붙이고, intermidate는 새이름 붙이기
        # 이후, 나머지 부분 weight deepcopy하며 new model 완성

    def forward(self,x):
        return self.model(x)

In [9]:
dist = Dist()
print(dist.model)
dummy = torch.ones((2,3,4,4))
print(dist.model(dummy).shape)
print("----------------------------------------")

dist.grow_layer(3)
dist.flush_model(3)
dummy = torch.ones((2,3,8,8))
print(dist.model(dummy).shape)
print("----------------------------------------")
dist.grow_layer(4)
dist.flush_model(4)
dummy = torch.ones((2,3,16,16))
print(dist.model(dummy).shape)
print("----------------------------------------")
dist.grow_layer(5)
dist.flush_model(5)
dummy = torch.ones((2,3,32,32))
print(dist.model(dummy).shape)
print("----------------------------------------")
dist.grow_layer(6)
dist.flush_model(6)
dummy = torch.ones((2,3,64,64))
print(dist.model(dummy).shape)
#grow -> flush ->grow ... 이렇게 번갈아가며 나와야 함

Sequential(
  (from_rgb): Sequential(
    (0): Conv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (last_block): Sequential(
    (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1))
    (3): ReLU()
    (4): View()
    (5): Linear(in_features=512, out_features=1, bias=True)
    (6): Sigmoid()
  )
)
torch.Size([2, 1])
----------------------------------------
resl=3 new_model Sequential(
  (fadein): FadeIn(
    (x_low): Sequential(
      (0): AvgPool2d(kernel_size=2, stride=2, padding=0)
      (1): Sequential(
        (from_rgb): Sequential(
          (0): Conv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
        )
      )
    )
    (x_high): Sequential(
      (next_rgb_block): Sequential(
        (0): Conv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (next_intermidate): 

In [13]:
class Gen(myUtils):
    def __init__(self):
        super().__init__()
        self.name = None
        self.model = self.init_gen()  

    def init_gen(self,):
        layers = nn.Sequential()
        layers.add_module('start_block',self.start_block()) 
        layers.add_module('to_rgb',self.to_rgb(2)) 
        #add_module을 안 하고 일단 layer만 리스트에 추가해도 나중에 이름 붙일 수 있나?
        return layers

    
    def intermidate_layer(self,resl): # resl=3부터 시작
        
        
        self.name = f"intermidate_{int(2**(resl-1))}x{int(2**(resl-1))}to{int(2**resl)}x{int(2**resl)}"

        if (6<=resl):
            layers=self.conv_twice_for_gen(2**(15-resl),2**(14-resl))
        
        elif (2<resl<6):
            layers=self.conv_twice_for_gen(512,512)

        return layers

    def grow_layer(self,resl):
        
        prev_rgb_block =self.deepcopy_module(self.model,'to_rgb')
        prev_tail = nn.Sequential(nn.Upsample(scale_factor=2),prev_rgb_block)

        next_tail = nn.Sequential()
        next_tail.add_module('next_intermidate',self.intermidate_layer(resl))
        next_tail.add_module('next_rgb_block',self.to_rgb(resl))
        
        #나머지 부분 weight deepcopy
        new_model = nn.Sequential()
        for name,module in self.model.named_children():
            if name!= 'to_rgb':
                new_model.add_module(name,module)
                new_model[-1].load_state_dict(module.state_dict())
        
        #fade in 적용
        new_model.add_module('fadein',FadeIn(prev_tail,next_tail))
        print(f"resl={resl} new_model",new_model)
        self.model = new_model

    def flush_model(self,resl):
        
        #fade in 모듈에서 layer2만 복사
        # layer2에서 rgb_block, intermidate 분리
        high_tail = self.model.fadein.x_high
        #print(f"resl={resl} high_head",high_head)
        intermidate_block = self.deepcopy_module(high_tail,'next_intermidate')
        rgb_block = self.deepcopy_module(high_tail,'next_rgb_block')
        print(f"==========intermidate",intermidate_block)
        new_model = nn.Sequential()
        for module_name,module in self.model.named_children():
            if module_name!= 'fadein':
                new_model.add_module(module_name,module)
                new_model[-1].load_state_dict(module.state_dict())    
        
        # name이 겹침!! ->나중에 self.name으로 고치기
        new_model.add_module(self.name,intermidate_block)  
        new_model.add_module('to_rgb',rgb_block)   
        self.model = new_model
        print(f"==========resl={resl} after flush",self.model)
        # rgb_block은 rgb_block으로 이름 붙이고, intermidate는 새이름 붙이기
        # 이후, 나머지 부분 weight deepcopy하며 new model 완성

        pass

In [14]:
gen = Gen()

print(gen.model)
dummy = torch.ones((2,512))
print(gen.model(dummy).shape)
print("----------------------------------------")

gen.grow_layer(3)
gen.flush_model(3)
print(gen.model(dummy).shape)
print("----------------------------------------")
gen.grow_layer(4)
gen.flush_model(4)
print(gen.model(dummy).shape)
print("----------------------------------------")
gen.grow_layer(5)
gen.flush_model(5)
print(gen.model(dummy).shape)
print("----------------------------------------")
gen.grow_layer(6)
gen.flush_model(6)
print(gen.model(dummy).shape)
#grow -> flush ->grow ... 이렇게 번갈아가며 나와야 함

Sequential(
  (start_block): Sequential(
    (0): View()
    (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1), padding=(3, 3))
    (2): ReLU()
    (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
  )
  (to_rgb): Sequential(
    (0): Conv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
)
torch.Size([2, 3, 4, 4])
----------------------------------------
resl=3 new_model Sequential(
  (start_block): Sequential(
    (0): View()
    (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1), padding=(3, 3))
    (2): ReLU()
    (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
  )
  (fadein): FadeIn(
    (x_low): Sequential(
      (0): Upsample(scale_factor=2.0, mode=nearest)
      (1): Sequential(
        (to_rgb): Sequential(
          (0): Conv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
        )
      )
    )
    (x_high): Sequential(