In [9]:
import torch 
import torch.nn as nn

In [11]:
class CNNBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.2) unit used by the discriminator.

    The convolution uses kernel_size=4, padding=1 and reflect padding. With
    stride=2 it halves an even spatial size; with stride=1 it shrinks each
    spatial dimension by one pixel.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Convolution stride.
    """

    def __init__(self, in_channels, out_channels, stride):
        super(CNNBlock, self).__init__()
        # Define the conv -> norm -> activation stack.
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                # A conv bias is redundant: BatchNorm2d re-centres the output right after.
                bias=False,
                # Mirror values at the tensor edges instead of padding with zeros.
                padding_mode="reflect",
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Apply convolution, batch normalisation and LeakyReLU to ``x``."""
        return self.conv(x)

In [24]:
class Discriminator(nn.Module):
    
    # features=[64,128,256,512] it is number of filters/ channels for each layer of the discriminator
    # in_channels=3 --> Colored input ( RGB ) 
    def __init__(self, in_channels=3,features=[64,128,256,512]):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                # the discriminator takes 2 inputs conctatenated together ( input image , target/output image ) 
                in_channels*2,
                feature[0],
                kernal_size=4,
                stride=2,
                padding=1,
                # This padding strategy mirrors the values at the tensor edges rather than using zeros or other constants
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )
        
        layers=[]
        in_channels=features[0]
        for feature in features[1:]:
            # stride=1 in the last layer ensures we don't downsample further
            layers.append(in_channels,feature,stride = 1 if feature == features[-1] else 2 ),
            in_channels = feature
        layers.append(
            conv2d(in_channels,1,kernal_size = 4, stride = 1 , padding = 1 , padding_mode= " reflect"),
        )
        # The * operator is called the "unpacking operator", and it's used to split a list into individual elements
        self.model = nn.Sequential(*layers)
        
    def forword(self,x,y):
        #  Concatenates x and y along channel dimension
        x= torch.cat([x,y],dim=1)
        # Passes the concatenated input through the first layer 
        x= self.initial(x)
        # Processes the feature maps through the remaining layers
        x= self.model(x)
        return x 

    def test():
       x=torch.randn((1,3,256,256))
       y=torch.rand((1,3,256,256))
       model = Distriminator(in_channels=3)
       preds = model(x,y)
       print(model)
       print(preds.shape)

    if __name__ == " __main__":
        test()