In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [6]:
class Encoder(nn.Module):
    """Convolutional feature extractor shared by the classifier and the generator.

    Three stride-1 convolutions (3 -> 64 -> 128 -> 256 channels; kernels 7/5/3 with
    padding 3/2/1), each followed by ReLU, then a 1x1 convolution that mixes the
    256 channels. Every conv is stride 1 with size-preserving padding, so the
    spatial size is unchanged: (N, 3, H, W) -> (N, 256, H, W).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        # 1x1 channel-mixing convolution (kept under the attribute name `shuffle`);
        # it does not change the spatial layout and has no activation after it.
        self.shuffle = nn.Conv2d(256, 256, kernel_size=1, stride=1)

    def forward(self, x):
        for layer in (self.conv1, self.conv2, self.conv3):
            x = torch.relu(layer(x))
        return self.shuffle(x)

class ClassificationModule(nn.Module):
    """Two-layer MLP classification head.

    flatten(non-batch dims) -> Linear(input_size, 512) -> ReLU -> Linear(512, num_classes).
    Returns raw logits (no softmax), suitable for nn.CrossEntropyLoss.

    Args:
        input_size: number of features after flattening all non-batch dimensions.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        # torch.flatten (unlike x.view(N, -1)) also works for non-contiguous inputs,
        # e.g. transposed or channels_last feature maps.
        x = torch.flatten(x, start_dim=1)
        return self.fc2(torch.relu(self.fc1(x)))

In [3]:
class Decoder(nn.Module):
    """Upsampling convolutional decoder: (N, 256, H, W) -> (N, 3, 4H, 4W).

    Nearest-neighbour x2 upsampling is applied before conv1 and again before conv3.
    All convs are 3x3 with padding 1, so they keep the spatial size. Hidden layers
    use ReLU; the output layer uses LeakyReLU with its default slope (0.01), so the
    output is not bounded to a fixed range.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)
        # Parameter-free, so one instance is reused for both upsampling steps.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        hidden = self.conv1(self.upsample(x)).relu()
        hidden = self.conv2(hidden).relu()
        hidden = self.conv3(self.upsample(hidden)).relu()
        return nn.functional.leaky_relu(self.conv4(hidden))

In [4]:
class Generator(nn.Module):
    """Image generator: (N, 256, H, W) -> (N, 3, 4H, 4W) with values in [-1, 1].

    conv1 runs at the input resolution; nearest-neighbour x2 upsampling precedes
    conv2 and conv3. All convs are 3x3 with padding 1 (size-preserving). Hidden
    layers use ReLU and the output layer uses tanh.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)
        # Parameter-free, so one instance is reused for both upsampling steps.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        features = torch.relu(self.conv1(x))
        features = torch.relu(self.conv2(self.upsample(features)))
        features = torch.relu(self.conv3(self.upsample(features)))
        # tanh squashes the image into [-1, 1].
        return torch.tanh(self.conv4(features))

class Discriminator(nn.Module):
    """Convolutional discriminator producing a spatial map of probabilities.

    Three 4x4 stride-2 convs (padding 1) with LeakyReLU(0.2) halve the size each
    time (3 -> 64 -> 128 -> 256 channels). A final 4x4 stride-1 unpadded conv then
    maps to 1 channel, followed by a sigmoid. For H and W divisible by 8, the result
    is (N, 1, H/8 - 3, W/8 - 3), so inputs must be at least 32x32. Because the output
    is already a probability, pair it with nn.BCELoss, not BCEWithLogitsLoss.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=0)

    def forward(self, x):
        for downsample in (self.conv1, self.conv2, self.conv3):
            x = nn.functional.leaky_relu(downsample(x), negative_slope=0.2)
        return torch.sigmoid(self.conv4(x))

In [8]:
class Network(nn.Module):
    """Multi-task model: shared encoder -> classifier head and image generator, plus a
    discriminator that scores the generated images.

    forward(x) returns (classification_output, generated_output, discriminator_output).

    Args:
        num_classes: number of classifier logits (default 10, as before).
        classifier_pool_size: spatial size that the encoder features are average-pooled
            to before the classification head. Encoder features keep the input's
            full resolution. Flattening them directly for a 256x256 image would need
            a Linear layer with 256*256*256*512 ~= 8.6e9 weights (~34 GB in float32).
            Pooling to (4, 4) gives 256*4*4 = 4096 features. It also makes the
            classifier independent of the input resolution.
    """

    def __init__(self, num_classes=10, classifier_pool_size=(4, 4)):
        super().__init__()
        self.encoder = Encoder()
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.classifier_pool = nn.AdaptiveAvgPool2d(classifier_pool_size)

        # Probe the encoder once to find the flattened size of the pooled features.
        # Restore the encoder's previous mode afterwards instead of leaving it in eval().
        was_training = self.encoder.training
        self.encoder.eval()
        with torch.no_grad():
            # Any spatial size works: the adaptive pool fixes the output size.
            # zeros (not randn) so the probe does not consume RNG state before the head is initialised.
            probe = torch.zeros(1, 3, 16, 16)
            pooled = self.classifier_pool(self.encoder(probe))
            input_size = pooled.flatten(1).size(1)
        self.encoder.train(was_training)

        self.classification_module = ClassificationModule(input_size, num_classes=num_classes)

    def forward(self, x):
        encoded_features = self.encoder(x)
        classification_output = self.classification_module(self.classifier_pool(encoded_features))
        # NOTE(review): the generator upsamples x4 while this encoder keeps resolution,
        # so generated_output is 4x the input size — confirm target_data matches.
        generated_output = self.generator(encoded_features)
        discriminator_output = self.discriminator(generated_output)
        return classification_output, generated_output, discriminator_output

# Seed before building the network so weight initialisation is reproducible
# across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Shared learning rate for all three optimizers.
LEARNING_RATE = 1e-3

# Define loss functions
classification_criterion = nn.CrossEntropyLoss()  # classifier emits raw logits
generator_criterion = nn.MSELoss()  # pixel-wise reconstruction loss
discriminator_criterion = nn.BCELoss()  # Discriminator already applies a sigmoid

# Initialize the network
network = Network()

# Define optimizers for each component; the encoder is updated together with the classifier.
optimizer_classification = optim.Adam(
    list(network.encoder.parameters()) + list(network.classification_module.parameters()),
    lr=LEARNING_RATE,
)
optimizer_generator = optim.Adam(network.generator.parameters(), lr=LEARNING_RATE)
optimizer_discriminator = optim.Adam(network.discriminator.parameters(), lr=LEARNING_RATE)


In [None]:
def train_network(network, train_loader, num_epochs, real_labels, fake_labels, target_data, 
                  classification_criterion, generator_criterion, discriminator_criterion, 
                  optimizer_classification, optimizer_generator, optimizer_discriminator, log_interval=100,
                  adversarial_weight=1.0):
    """Jointly train the encoder/classifier/generator and the discriminator.

    Each batch performs two updates:

    1. Encoder + classifier + generator minimise
       ``classification_loss + reconstruction_loss + adversarial_weight * adversarial_loss``.
       The adversarial term asks the discriminator to score the generated images as
       real. These losses share the encoder's autograd graph, so they are summed and
       back-propagated once. Separate backward() calls would fail on the second one.
    2. The discriminator learns to score ``target_data`` as real and the detached
       generated images as fake. The detach stops its loss from reaching the
       generator or encoder.

    Args:
        network: module whose forward returns (classification_output, generated_output,
            discriminator_output) and that has a ``discriminator`` submodule.
        train_loader: iterable of (input_data, target_labels) batches supporting ``len()``.
        num_epochs: number of passes over ``train_loader``.
        real_labels, fake_labels: discriminator targets. They must match the
            discriminator output shape. They are reused for every batch, so a smaller
            final batch would mismatch (TODO confirm the loader drops it).
        target_data: reference images. generated_output is compared against them,
            and the discriminator treats them as real. The same tensor is used for
            every batch.
        classification_criterion, generator_criterion, discriminator_criterion: loss modules.
        optimizer_classification: steps the encoder and classifier parameters.
        optimizer_generator: steps the generator parameters.
        optimizer_discriminator: steps the discriminator parameters.
        log_interval: print progress when ``batch_idx % log_interval == 0``
            (including the first batch of each epoch).
        adversarial_weight: weight of the adversarial term in the generator objective.

    Returns:
        (classification_losses, generator_losses, discriminator_losses), one float per
        batch. generator_losses holds the full generator objective
        (reconstruction + weighted adversarial).
    """
    classification_losses = []
    generator_losses = []
    discriminator_losses = []
    # Running sums keep the cumulative averages O(1) per log line.
    total_classification = 0.0
    total_generator = 0.0
    total_discriminator = 0.0

    network.train()

    for epoch in range(num_epochs):
        for batch_idx, (input_data, target_labels) in enumerate(train_loader):
            classification_output, generated_output, discriminator_output = network(input_data)

            # ---- 1) Encoder / classifier / generator update ----
            classification_loss = classification_criterion(classification_output, target_labels)
            reconstruction_loss = generator_criterion(generated_output, target_data)
            # The generator wants its images to be scored as real.
            adversarial_loss = discriminator_criterion(discriminator_output, real_labels)
            generator_loss = reconstruction_loss + adversarial_weight * adversarial_loss

            optimizer_classification.zero_grad()
            optimizer_generator.zero_grad()
            (classification_loss + generator_loss).backward()
            optimizer_classification.step()
            optimizer_generator.step()

            # ---- 2) Discriminator update ----
            # Also clears the gradients the adversarial term left on the discriminator.
            optimizer_discriminator.zero_grad()
            real_output = network.discriminator(target_data)
            # Detach so this loss does not back-propagate into the generator/encoder.
            fake_output = network.discriminator(generated_output.detach())
            discriminator_loss = (discriminator_criterion(real_output, real_labels)
                                  + discriminator_criterion(fake_output, fake_labels))
            discriminator_loss.backward()
            optimizer_discriminator.step()

            classification_value = classification_loss.item()
            generator_value = generator_loss.item()
            discriminator_value = discriminator_loss.item()
            classification_losses.append(classification_value)
            generator_losses.append(generator_value)
            discriminator_losses.append(discriminator_value)
            total_classification += classification_value
            total_generator += generator_value
            total_discriminator += discriminator_value

            if batch_idx % log_interval == 0:
                print('Epoch [{}/{}], Batch [{}/{}], '
                      'Classification Loss: {:.4f}, '
                      'Generator Loss: {:.4f}, '
                      'Discriminator Loss: {:.4f}'.format(
                          epoch+1, num_epochs, batch_idx+1, len(train_loader),
                          classification_value,
                          generator_value,
                          discriminator_value))

                # Cumulative averages over every batch seen so far (all epochs).
                seen = len(classification_losses)
                print('Avg Classification Loss: {:.4f}, '
                      'Avg Generator Loss: {:.4f}, '
                      'Avg Discriminator Loss: {:.4f}'.format(
                          total_classification / seen,
                          total_generator / seen,
                          total_discriminator / seen))

    return classification_losses, generator_losses, discriminator_losses
