In [215]:
import torch
import torchvision

In [216]:
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
# Preprocessing shared by both splits.
# ToTensor() yields pixel values in [0, 1]; Normalize((0.5,), (0.5,)) then maps them to
# [-1, 1] so that real images occupy the same value range as the Generator's tanh output.
# Without this, the discriminator can separate real from fake on value range alone.
mnist_transform=torchvision.transforms.Compose([
    ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

# MNIST is downloaded into '/files/' on first use and read from there afterwards.
train_data_raw=torchvision.datasets.MNIST('/files/', train=True, download=True, transform=mnist_transform)
test_data_raw=torchvision.datasets.MNIST('/files/', train=False, download=True, transform=mnist_transform)

# Training batches are reshuffled every epoch; the held-out split keeps a fixed order
# so that anything computed on it is repeatable from run to run.
train_data=DataLoader(train_data_raw,batch_size=512,shuffle=True)
test_data=DataLoader(test_data_raw,batch_size=512,shuffle=False)

In [217]:
# Sanity check: pull one batch from the *training* loader. Despite its name, `test`
# holds training data; indexing it as test[0][0] below selects the first image of the batch.
test=next(iter(train_data))
# Shape of a single image tensor (the captured output shows [1, 28, 28]).
print(test[0][0].shape)
# Class label strings exposed by the dataset object.
class_names=train_data_raw.classes
print(class_names)

torch.Size([1, 28, 28])
['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']


In [218]:
import torch.nn as nn
class Discriminator(nn.Module):
    """Fully connected discriminator over flattened 28x28 images.

    The network is three hidden blocks of Linear -> LeakyReLU(0.2) -> Dropout(0.3)
    (widths 1024, 512, 256) followed by a final Linear layer, with a sigmoid applied
    to the result in ``forward``.

    Args:
        output_shape: Number of output units of the final Linear layer. Use 1 for a
            single real/fake probability per image.
    """

    def __init__(self, output_shape: int):
        super().__init__()
        self.layers=nn.Sequential(
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            # Honour the constructor argument instead of a hard-coded width of 1.
            nn.Linear(256, output_shape)
        )

    def forward(self, x: torch.Tensor):
        """Flatten ``x`` to rows of 784 values and return sigmoid scores.

        Args:
            x: Tensor whose total element count is a multiple of 784
                (for example ``(N, 1, 28, 28)`` or ``(N, 784)``).

        Returns:
            Tensor of shape ``(N, output_shape)`` with values in (0, 1).
        """
        # reshape (unlike view) also accepts non-contiguous inputs such as transposed tensors.
        x=x.reshape(-1,784)
        return torch.sigmoid(self.layers(x))

# Module-level discriminator instance, constructed with output_shape=1; it is the
# object handed to train() in the final cell.
discriminator=Discriminator(1)


In [219]:
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to a 1x28x28 image.

    Hidden widths are 256, 512 and 1024, each followed by LeakyReLU(0.2); a final
    Linear layer produces 784 values that ``forward`` squashes with tanh.

    Args:
        input_dim: Width of the input noise vector.
    """

    def __init__(self,input_dim):
        super().__init__()
        widths = [input_dim, 256, 512, 1024]
        stack = []
        # One Linear + LeakyReLU pair per consecutive pair of widths.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2))
        # Output projection to 784 = 28 * 28 pixels (no activation here; tanh is applied in forward).
        stack.append(nn.Linear(1024, 784))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return tanh-activated images of shape (N, 1, 28, 28) for noise ``x``."""
        flat_pixels = torch.tanh(self.layers(x))
        return flat_pixels.view(-1, 1, 28, 28)

# Module-level generator instance with a 128-wide input; the training functions in the
# next cell draw noise vectors of width 128 to match this value.
generator=Generator(128)

In [220]:
def gen_train_step(generator:  torch.nn.Module,
                   discriminator: torch.nn.Module,
                   dataloader: torch.utils.data.DataLoader,
                   criterion: torch.nn.Module,
                   gen_optimizer: torch.optim.Optimizer,
                   batch_size: int = 64,
                   latent_dim: int = 128):
    """Perform a single generator update and return its loss.

    A batch of noise is pushed through the generator and then the discriminator;
    the generator is rewarded when the discriminator scores its samples as real
    (target of all ones).

    Args:
        generator: Network mapping noise of width ``latent_dim`` to samples.
        discriminator: Network scoring the generated samples.
        dataloader: Not used by this step; kept so existing call sites keep working.
        criterion: Loss comparing discriminator scores against the all-ones target.
        gen_optimizer: Optimizer holding the generator's parameters.
        batch_size: Number of noise vectors to sample (default 64).
        latent_dim: Width of each noise vector (default 128).

    Returns:
        The generator loss for this step as a Python float.
    """
    generator.zero_grad()

    # Sample noise on the same device as the generator so the step also works off-CPU.
    first_param = next(generator.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    z = torch.randn(batch_size, latent_dim, device=device)

    gen_output = generator(z)
    disc_output = discriminator(gen_output)

    # Build the "real" target from the discriminator output so shape, dtype and
    # device always agree with what the criterion receives.
    y = torch.ones_like(disc_output)
    gen_loss = criterion(disc_output, y)

    # Backpropagate through the discriminator, but step ONLY the generator's optimizer.
    gen_loss.backward()
    gen_optimizer.step()

    return gen_loss.item()


def train_step(discriminator: torch.nn.Module,
               generator:  torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               criterion: torch.nn.Module,
               disc_optimizer: torch.optim.Optimizer,
               gen_optimizer: torch.optim.Optimizer,
               latent_dim: int = 128):
    """Run one epoch of alternating discriminator and generator updates.

    For every batch the discriminator is first updated on real images (target 1)
    and on generated images (target 0); the generator is then updated so that the
    discriminator scores fresh samples as real.

    Args:
        discriminator: Network scoring images as real or fake.
        generator: Network mapping noise of width ``latent_dim`` to images.
        dataloader: Iterable yielding ``(images, labels)`` pairs; labels are ignored.
        criterion: Loss applied to discriminator scores and 0/1 targets.
        disc_optimizer: Optimizer holding the discriminator's parameters.
        gen_optimizer: Optimizer holding the generator's parameters.
        latent_dim: Width of the sampled noise vectors (default 128).

    Returns:
        A tuple ``(disc_loss, gen_loss)`` with the per-batch mean discriminator and
        generator losses; ``(0.0, 0.0)`` when the dataloader yields no batches.
    """
    # Put both models in train mode
    discriminator.train()
    generator.train()

    disc_loss_total, gen_loss_total = 0.0, 0.0
    num_batches = 0

    for X, _ in dataloader:
        batch_size = X.size(0)

        # ---- Discriminator update ----
        disc_optimizer.zero_grad()

        # 1. Real images should be scored as 1.
        real_pred = discriminator(X)
        real_loss = criterion(real_pred, torch.ones_like(real_pred))

        # 2. Generated images should be scored as 0. The fake batch is detached so
        #    this backward pass does not push gradients into the generator.
        z = torch.randn(batch_size, latent_dim, device=X.device)
        fake_images = generator(z).detach()
        fake_pred = discriminator(fake_images)
        fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))

        # 3. Backpropagate the combined loss and step the discriminator.
        disc_loss = real_loss + fake_loss
        disc_loss.backward()
        disc_optimizer.step()

        # ---- Generator update ----
        gen_optimizer.zero_grad()

        # 4. Fresh noise; the generator wants the discriminator to answer 1.
        z = torch.randn(batch_size, latent_dim, device=X.device)
        gen_pred = discriminator(generator(z))
        gen_loss = criterion(gen_pred, torch.ones_like(gen_pred))

        # 5. Gradients also reach the discriminator here, but only the generator's
        #    optimizer steps; the discriminator's are cleared at the top of the next batch.
        gen_loss.backward()
        gen_optimizer.step()

        disc_loss_total += disc_loss.item()
        gen_loss_total += gen_loss.item()
        num_batches += 1

    # Average per batch; guard against an empty dataloader.
    if num_batches == 0:
        return 0.0, 0.0
    return disc_loss_total / num_batches, gen_loss_total / num_batches

In [221]:
from tqdm.auto import tqdm

# 1. Take in various parameters required for training and test steps
def train(discriminator: torch.nn.Module,
          generator:  torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          epochs: int = 5):
    """Train the two networks for ``epochs`` epochs and collect per-epoch metrics.

    A BCELoss criterion and one Adam optimizer (lr=0.001) per network are created
    here and handed to ``train_step`` once per epoch.

    Args:
        discriminator: Discriminator network.
        generator: Generator network.
        train_dataloader: Loader passed to ``train_step`` every epoch.
        test_dataloader: Accepted but not used anywhere in this function.
        epochs: Number of passes over ``train_dataloader``.

    Returns:
        Dict with keys "train_loss", "train_acc", "test_loss" and "test_acc".
        "train_loss" / "train_acc" receive the first / second value returned by
        ``train_step`` for each epoch (the second one is printed as
        ``gen_train_loss``); the two "test_*" lists are never appended to.
    """
    # Criterion first, then the discriminator's and the generator's optimizer.
    bce = nn.BCELoss()
    d_opt = torch.optim.Adam(params=discriminator.parameters(), lr=0.001)
    g_opt = torch.optim.Adam(params=generator.parameters(), lr=0.001)

    # One empty list per tracked metric.
    history = {}
    for metric_name in ("train_loss", "train_acc", "test_loss", "test_acc"):
        history[metric_name] = []

    for epoch_idx in tqdm(range(epochs)):
        step_metrics = train_step(discriminator=discriminator,
                                  generator=generator,
                                  dataloader=train_dataloader,
                                  criterion=bce,
                                  disc_optimizer=d_opt,
                                  gen_optimizer=g_opt)
        first_metric, second_metric = step_metrics

        # Progress line for this epoch.
        report = (
            f"Epoch: {epoch_idx+1} | "
            f"disc_train_loss: {first_metric:.4f} | "
            f"gen_train_loss: {second_metric:.4f} | "
        )
        print(report)

        # Record this epoch's values.
        history["train_loss"].append(first_metric)
        history["train_acc"].append(second_metric)

    return history

In [None]:
# Seed the CPU and CUDA random number generators.
# NOTE(review): `discriminator` and `generator` are instantiated in earlier cells, i.e.
# before this seed is set, so the seed here does not cover their initial weights when
# the notebook is run top to bottom — confirm whether that is intended.
torch.manual_seed(42) 
torch.cuda.manual_seed(42)

# Number of passes over the training loader
NUM_EPOCHS = 1


# The loss function and both optimizers are created inside train(); nothing to set up here.



# Start the wall-clock timer
from timeit import default_timer as timer 
start_time = timer()

# Run training; train() returns a dict of per-epoch metric lists
model_0_results = train(discriminator=discriminator,
                        generator=generator, 
                        train_dataloader=train_data,
                        test_dataloader=test_data,
                        epochs=NUM_EPOCHS)

# Stop the timer and report the elapsed time in seconds
end_time = timer()
print(f"Total training time: {end_time-start_time:.3f} seconds")

  0%|          | 0/1 [00:00<?, ?it/s]