In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [15]:
class Generator(nn.Module):
    def __init__(self, noise_dim, label_dim, output_dim):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(label_dim, label_dim)
        self.net = nn.Sequential(
            nn.Linear(noise_dim + label_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, output_dim),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        labels = self.label_emb(labels)
        gen_input = torch.cat((noise, labels), dim=1)
        return self.net(gen_input)


In [24]:
class Critic(nn.Module):
    def __init__(self, input_dim, label_dim):
        super(Critic, self).__init__()
        self.label_emb = nn.Embedding(label_dim, label_dim)
        self.net = nn.Sequential(
            nn.Linear(input_dim + label_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, inputs, labels):
        labels = self.label_emb(labels).view(labels.size(0), -1)
        inputs = inputs.view(inputs.size(0), -1)  # Flatten the inputs
        disc_input = torch.cat((inputs, labels), dim=1)
        return self.net(disc_input)


In [25]:
def gradient_penalty(C, real_data, fake_data, real_labels, lambda_gp):
    batch_size = real_data.size(0)
    epsilon = torch.rand(batch_size, 1, device=real_data.device)
    interpolates = epsilon * real_data + ((1 - epsilon) * fake_data)
    interpolates.requires_grad_(True)
    
    interpolated_labels = real_labels.view(batch_size, -1)  # Correct the shape of labels
    d_interpolates = C(interpolates, interpolated_labels)
    fake = torch.ones(d_interpolates.size(), device=real_data.device)
    
    gradients = torch.autograd.grad(
        outputs=d_interpolates, inputs=interpolates,
        grad_outputs=fake, create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
    return gradient_penalty


In [31]:
def train_gan(data, labels, epochs=300, batch_size=32, noise_dim=100, lr=3e-4, lambda_gp=10):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    data = data.float()
    labels = labels.long()
    
    dataset = TensorDataset(data, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    G = Generator(noise_dim, 8, data.shape[1] * data.shape[2]).to(device)
    C = Critic(data.shape[1] * data.shape[2], 8).to(device)
    
    optimizer_g = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_c = optim.Adam(C.parameters(), lr=lr, betas=(0.5, 0.999))
    
    for epoch in range(epochs):
        for i, (real_data, real_labels) in enumerate(dataloader):
            real_data = real_data.to(device).view(-1, data.shape[1] * data.shape[2])
            real_labels = real_labels.to(device)
            
            batch_size = real_data.size(0)
            noise = torch.randn(batch_size, noise_dim).to(device)
            fake_labels = torch.randint(0, 8, (batch_size,)).to(device)
            fake_data = G(noise, fake_labels)
            
            real_targets = torch.ones(batch_size, 1).to(device)
            fake_targets = torch.zeros(batch_size, 1).to(device)
            
            # Train Critic
            optimizer_c.zero_grad()
            real_output = C(real_data, real_labels)
            fake_output = C(fake_data.detach(), fake_labels)
            gp = gradient_penalty(C, real_data, fake_data, real_labels, lambda_gp)
            loss_c = -torch.mean(real_output) + torch.mean(fake_output) + gp
            loss_c.backward(retain_graph=True)
            optimizer_c.step()
            
            # Train Generator every 5 iterations of Critic
            if i % 5 == 0:
                optimizer_g.zero_grad()
                fake_output = C(fake_data, fake_labels)
                loss_g = -torch.mean(fake_output)
                loss_g.backward(retain_graph=True)
                optimizer_g.step()
        
        if (epoch + 1) % 20 == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Critic Loss: {loss_c.item():.4f}, Generator Loss: {loss_g.item():.4f}')
            torch.save(G.state_dict(), f'cwgan_generator_{epoch + 1}.pth')

    return G, C

In [32]:
data = torch.load('data_mm_user1.pth')  
labels = torch.load('labels_user1.pth')  

print(f'Data size: {data.shape}')
print(f'Labels size: {labels.shape}')

G, C = train_gan(data, labels)

  data = torch.load('data_mm_user1.pth')
  labels = torch.load('labels_user1.pth')


Data size: torch.Size([1800, 50, 30])
Labels size: torch.Size([1800])
Epoch [20/300], Critic Loss: -128739.6562, Generator Loss: -3486.6536
Epoch [40/300], Critic Loss: -95570.0781, Generator Loss: -3522.4817
Epoch [60/300], Critic Loss: -80596.2266, Generator Loss: -938.2299
Epoch [80/300], Critic Loss: -162228.7500, Generator Loss: 90707.3125
Epoch [100/300], Critic Loss: -64606.4805, Generator Loss: 247655.0938
Epoch [120/300], Critic Loss: -1019356.1875, Generator Loss: 1068496.0000
Epoch [140/300], Critic Loss: -249934.1250, Generator Loss: 3819869.5000
Epoch [160/300], Critic Loss: -969688.4375, Generator Loss: 5113688.0000


KeyboardInterrupt: 