In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [11]:
class Generator(nn.Module):
    """Conditional MLP generator: maps (noise, class label) to a flat sample vector.

    The integer label is looked up in a learned ``label_dim x label_dim`` embedding,
    concatenated to the noise along dim 1, and passed through a 4-layer MLP
    (hidden widths 256 -> 512 -> 1024) whose final ``Tanh`` bounds outputs to [-1, 1].

    Args:
        noise_dim: number of noise features per sample.
        label_dim: number of classes; also used as the label-embedding width.
        output_dim: length of each generated (flattened) sample.
    """

    def __init__(self, noise_dim, label_dim, output_dim):
        super().__init__()
        self.label_emb = nn.Embedding(label_dim, label_dim)

        # Hidden stack built pairwise: Linear followed by ReLU for each width step.
        widths = [noise_dim + label_dim, 256, 512, 1024]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(in_features, out_features), nn.ReLU()]
        layers += [nn.Linear(widths[-1], output_dim), nn.Tanh()]
        self.net = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate samples for ``labels``.

        Args:
            noise: presumably a (batch, noise_dim) tensor — TODO confirm at call sites.
            labels: integer class ids in ``[0, label_dim)``, one per row of ``noise``.

        Returns:
            Tensor of shape (batch, output_dim) with values in [-1, 1].
        """
        conditioned = torch.cat([noise, self.label_emb(labels)], dim=1)
        return self.net(conditioned)

In [23]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator: scores how "real" a (sample, label) pair looks.

    The integer label is embedded (``label_dim x label_dim`` table), concatenated to the
    flat sample along dim 1, and fed through a LeakyReLU(0.2) MLP
    (1024 -> 512 -> 256 -> 1). The final ``Sigmoid`` yields a probability in (0, 1).

    Args:
        input_dim: length of each flattened input sample.
        label_dim: number of classes; also used as the label-embedding width.
    """

    def __init__(self, input_dim, label_dim):
        super().__init__()
        self.label_emb = nn.Embedding(label_dim, label_dim)

        widths = (input_dim + label_dim, 1024, 512, 256)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)))
        stack.extend((nn.Linear(widths[-1], 1), nn.Sigmoid()))
        self.net = nn.Sequential(*stack)

    def forward(self, inputs, labels):
        """Return a (batch, 1) tensor of real-vs-fake probabilities.

        Args:
            inputs: presumably a (batch, input_dim) tensor of flattened samples — TODO confirm.
            labels: integer class ids in ``[0, label_dim)``, one per row of ``inputs``.
        """
        embedded = self.label_emb(labels)
        return self.net(torch.cat([inputs, embedded], dim=1))

In [10]:
def train_gan(data, labels, epochs=300, batch_size=32, noise_dim=100, lr=3e-4,
              num_classes=8, log_every=20):
    """Train a conditional GAN (``Generator`` / ``Discriminator``) on labelled samples.

    Each sample in ``data`` is flattened to a vector of ``data.shape[1:].numel()``
    features (e.g. 30 * 50 for data of shape (N, 30, 50)). The generator is
    conditioned on an integer class label.

    Args:
        data: tensor of shape (N, ...) holding N real samples. The generator ends in
            ``Tanh``, so real samples presumably need to be scaled to [-1, 1] — TODO confirm.
        labels: tensor of shape (N,) with integer class ids in ``[0, num_classes)``.
        epochs: number of passes over the dataset.
        batch_size: mini-batch size for the DataLoader.
        noise_dim: length of the generator's noise vector.
        lr: Adam learning rate for both networks.
        num_classes: number of distinct labels (size of both label embeddings).
        log_every: positive int. Every ``log_every`` epochs, print the epoch-mean
            D/G losses and save the generator weights to ``generator_<epoch>.pth``
            in the current working directory.

    Returns:
        (G, D): the trained generator and discriminator.

    Raises:
        ValueError: if ``data`` is empty or any label lies outside ``[0, num_classes)``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data = data.float()     # float32, to match the models' parameter dtype
    labels = labels.long()  # nn.Embedding requires int64 indices

    dataset = TensorDataset(data, labels)  # also checks len(data) == len(labels)
    if len(dataset) == 0:
        raise ValueError('train_gan needs at least one training sample')
    # Out-of-range ids would otherwise fail deep inside nn.Embedding (a device-side
    # assert on CUDA), so reject them with a readable message here.
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError(
            f'labels must lie in [0, {num_classes}), '
            f'got [{labels.min().item()}, {labels.max().item()}]')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    sample_dim = data.shape[1:].numel()  # flattened length of one sample

    G = Generator(noise_dim, num_classes, sample_dim).to(device)
    D = Discriminator(sample_dim, num_classes).to(device)

    criterion = nn.BCELoss()  # D already ends in Sigmoid, so BCELoss takes probabilities
    optimizer_g = optim.Adam(G.parameters(), lr=lr)
    optimizer_d = optim.Adam(D.parameters(), lr=lr)

    for epoch in range(epochs):
        d_loss_sum = 0.0
        g_loss_sum = 0.0
        for real_data, real_labels in dataloader:
            real_data = real_data.to(device).reshape(-1, sample_dim)
            real_labels = real_labels.to(device)
            n = real_data.size(0)  # the final batch may be smaller than batch_size

            noise = torch.randn(n, noise_dim, device=device)
            fake_labels = torch.randint(0, num_classes, (n,), device=device)
            fake_data = G(noise, fake_labels)

            real_targets = torch.ones(n, 1, device=device)
            fake_targets = torch.zeros(n, 1, device=device)

            # Discriminator step: push D(real) -> 1 and D(fake) -> 0. fake_data is
            # detached so this step does not backpropagate into G.
            optimizer_d.zero_grad()
            loss_real = criterion(D(real_data, real_labels), real_targets)
            loss_fake = criterion(D(fake_data.detach(), fake_labels), fake_targets)
            loss_d = loss_real + loss_fake
            loss_d.backward()
            optimizer_d.step()

            # Generator step: reward G when the freshly updated D scores its samples as real.
            optimizer_g.zero_grad()
            loss_g = criterion(D(fake_data, fake_labels), real_targets)
            loss_g.backward()
            optimizer_g.step()

            # Accumulate detached tensors (no per-batch host sync); converted once per log.
            d_loss_sum += loss_d.detach()
            g_loss_sum += loss_g.detach()

        if (epoch + 1) % log_every == 0:
            n_batches = len(dataloader)
            print(f'Epoch [{epoch + 1}/{epochs}], '
                  f'D Loss: {float(d_loss_sum) / n_batches:.4f}, '
                  f'G Loss: {float(g_loss_sum) / n_batches:.4f}')
            torch.save(G.state_dict(), f'generator_{epoch + 1}.pth')

    return G, D


In [None]:
# Load the real samples and their integer class labels, then train the conditional GAN.
# NOTE(review): torch.load unpickles its input; only load files from a trusted source.
# map_location='cpu' lets tensors saved from a GPU session load on a CPU-only machine;
# train_gan moves each batch to the training device itself.
data = torch.load('data_mm_user1.pth', map_location='cpu')
labels = torch.load('labels_user1.pth', map_location='cpu')

print(f'Data size: {data.shape}')
print(f'Labels size: {labels.shape}')

G, D = train_gan(data, labels)