In [1]:
import torch
import torch.nn as nn

In [2]:
class ResnetBlock(nn.Module):
    """Residual block: ``x + BN(Conv(PReLU(BN(Conv(x)))))``.

    Both convolutions are 3x3, stride 1, padding 1, so channel count and
    spatial size are preserved and the residual sum is shape-compatible.

    Args:
        channels: number of input (and output) feature channels.
    """

    def __init__(self, channels=64):
        super().__init__()

        def conv3x3():
            # Shape-preserving 3x3 convolution.
            return nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

        # The attribute name and layer order are kept (identity.0 ... identity.4)
        # so existing checkpoints still load.
        self.identity = nn.Sequential(
            conv3x3(),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            conv3x3(),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        residual = self.identity(x)
        return residual + x
    

In [3]:
class UpSampleBlock(nn.Module):
    """Sub-pixel upsampling block: Conv -> PixelShuffle -> PReLU.

    The 3x3 convolution expands the features to ``channels * up_scale**2``.
    PixelShuffle then rearranges them into a map with ``channels`` channels
    and spatial size multiplied by ``up_scale``.

    Args:
        up_scale: integer spatial upscaling factor.
        channels: number of input (and output) feature channels.
    """

    def __init__(self, up_scale = 2, channels = 64):
        super().__init__()
        expanded_channels = channels * up_scale ** 2
        layers = [
            nn.Conv2d(channels, expanded_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(up_scale),
            nn.PReLU(),
        ]
        # Kept as `self.model` with this exact order so state_dict keys match.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Upsample ``x`` by ``up_scale`` in both spatial dimensions."""
        return self.model(x)

In [17]:
class generator(nn.Module):
    r"""SRGAN-style super-resolution generator.

    Pipeline:
      1. Stem: 9x9 conv to ``step_channels`` + PReLU.
      2. Body: ``B_numResnetBlock`` residual blocks followed by conv + BatchNorm.
         A long skip connection adds the stem output to the body output.
      3. Tail: ``num_upsample_blocks`` x2 PixelShuffle stages, then a 9x9 conv
         back to ``in_channels``.

    Every convolution is padded to preserve spatial size. Each upsampling
    stage doubles height and width, so the output is
    ``2 ** num_upsample_blocks`` times larger (4x with the default).

    Args:
        B_numResnetBlock: number of ``ResnetBlock`` modules in the body.
        in_channels: channels of both the input and the output image.
        step_channels: feature width used throughout the network.
        num_upsample_blocks: number of x2 upsampling stages. The default of 2
            matches the previously hard-coded value.
    """

    def __init__(self, B_numResnetBlock = 16, in_channels = 3, step_channels = 64,
                 num_upsample_blocks = 2):
        super().__init__()

        self.init_model = nn.Sequential(
            nn.Conv2d(in_channels, step_channels, kernel_size=9, padding=4),
            nn.PReLU()
        )

        mid_model = [ResnetBlock(step_channels) for _ in range(B_numResnetBlock)]
        # The post-residual conv + BN stays wrapped in its own nn.Sequential so
        # state_dict keys (e.g. "mid_model.16.0.weight") are unchanged.
        mid_model.append(
            nn.Sequential(
                nn.Conv2d(step_channels, step_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(step_channels)
            )
        )
        self.mid_model = nn.Sequential(*mid_model)

        # Number of x2 upsampling stages; the total upscale factor is 2 ** self.r.
        self.r = num_upsample_blocks
        end_model = [
            UpSampleBlock(up_scale = 2, channels = step_channels)
            for _ in range(self.r)
        ]
        # Final projection back to image channels. It is wrapped in nn.Sequential
        # for state_dict compatibility with the original layout.
        end_model.append(
            nn.Sequential(
                nn.Conv2d(step_channels, in_channels, kernel_size=9, padding=4)
            )
        )
        self.end_model = nn.Sequential(*end_model)

        self._weight_initializer()

    def forward(self, x):
        """Super-resolve a batch of images.

        Args:
            x: 4-D tensor of shape (N, in_channels, H, W). BatchNorm2d in the
                body requires 4-D input.

        Returns:
            Tensor of shape (N, in_channels, H * 2**self.r, W * 2**self.r).
        """
        features = self.init_model(x)
        # Long skip connection around the whole residual body.
        x = self.mid_model(features) + features
        return self.end_model(x)

    def _weight_initializer(self):
        r"""Default weight initializer for this generator.

        Applies Kaiming-normal initialization to every convolution weight and
        sets BatchNorm affine parameters to weight=1, bias=0. Convolution biases
        keep PyTorch's default initialization. Subclasses that need a different
        scheme can override this method.
        """
        for m in self.modules():
            # The network is built from nn.Conv2d only. Checking for
            # nn.ConvTranspose2d alone never matched, which left this branch
            # dead and the convolutions at their default initialization.
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)
        

In [28]:
class discriminator(nn.Module):
    r"""SRGAN-style discriminator that outputs one probability per image.

    Architecture:
      1. Stem: conv + LeakyReLU(0.2), then a stride-2 conv + BN + LeakyReLU.
      2. Three stages, each of which doubles the channel count and halves the
         spatial size through a stride-2 conv.
      3. Head: global average pooling, two 1x1 convs and a sigmoid.

    Because of the adaptive pooling, the head works for any input spatial size.

    Args:
        in_channels: channels of the input image.
        step_channels: width of the stem. The last stage has
            ``step_channels * 8`` channels.
    """

    def __init__(self, in_channels = 3, step_channels = 64):
        super().__init__()

        model = []

        # Stem: the first layer has no BatchNorm. The second conv uses stride 2
        # to halve the spatial size.
        model.append(
            nn.Sequential(
                nn.Conv2d(in_channels, step_channels, kernel_size=3, padding=1),
                nn.LeakyReLU(.2),
                nn.Conv2d(step_channels, step_channels, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(step_channels),
                nn.LeakyReLU(.2)
            )
        )

        # Running channel count. It is step_channels * 2**3 after the loop.
        self.expansion = step_channels

        # Three stages, each doubling channels and halving the spatial size.
        for _ in range(3):
            model.append(
                nn.Sequential(
                    nn.Conv2d(self.expansion, self.expansion*2, kernel_size=3, padding=1),
                    nn.BatchNorm2d(self.expansion*2),
                    nn.LeakyReLU(.2),
                    nn.Conv2d(self.expansion*2, self.expansion*2, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(self.expansion*2),
                    nn.LeakyReLU(.2)
                )
            )
            self.expansion = self.expansion*2

        # Head: after pooling to 1x1, the two 1x1 convs act as a small MLP.
        # The sigmoid means outputs are already probabilities, so pair this
        # with nn.BCELoss rather than nn.BCEWithLogitsLoss.
        model.append(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(self.expansion, self.expansion*2, kernel_size=1),
                nn.LeakyReLU(.2),
                nn.Conv2d(self.expansion*2, 1, kernel_size=1),
                nn.Sigmoid()
            )
        )

        self.model = nn.Sequential(*model)

        self._weight_initializer()

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: 4-D tensor of shape (N, in_channels, H, W). BatchNorm2d requires
                4-D input.

        Returns:
            1-D tensor of shape (N,) with values in [0, 1], one per image.
        """
        x = self.model(x)
        # (N, 1, 1, 1) -> (N,)
        return x.view(-1)

    def _weight_initializer(self):
        r"""Default weight initializer for this discriminator.

        Applies Kaiming-normal initialization to every convolution weight and
        sets BatchNorm affine parameters to weight=1, bias=0. Convolution biases
        keep PyTorch's default initialization. Subclasses that need a different
        scheme can override this method.
        """
        for m in self.modules():
            # The network is built from nn.Conv2d only. Checking for
            # nn.ConvTranspose2d alone never matched, which left this branch
            # dead and the convolutions at their default initialization.
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)
        