In [4]:
import torch
import os
import torch.nn as nn
from torch.utils.data import DataLoader
from cod.Dataset import MyDataset
from cod.models import resnet18
from cod.Transform import get_transforms

In [5]:
# Runtime configuration: compute device, batch size, data locations and RNG seed.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 64
dirdata = "./data/OxfordPets"
train_files = "train.csv"
val_files = "test.csv"

# Seed torch's RNGs (CPU and CUDA) so weight init, latent noise and dropout
# masks are reproducible under Restart & Run All.
# NOTE(review): if MyDataset's sampler draws from `random`/numpy, seed those too.
seed = 42
torch.manual_seed(seed)

In [6]:
class Discriminator(nn.Module):
    """Fully-connected discriminator scoring images as real (-> 1) or fake (-> 0).

    Args:
        img_size: side length of the square input images (default 112).
        channels: number of image channels (default 3).

    ``forward`` accepts a batch whose per-sample element count is
    ``channels * img_size * img_size`` (e.g. ``(batch, channels, img_size, img_size)``)
    and returns a ``(batch, 1)`` tensor of Sigmoid probabilities.
    """

    def __init__(self, img_size=112, channels=3):
        super().__init__()
        # With the defaults this is 112 * 112 * 3 == 12544 * 3, the width the
        # original hard-coded; storing it is not a parameter, so state_dicts are unchanged.
        self.in_features = channels * img_size * img_size
        self.model = nn.Sequential(
            nn.Linear(self.in_features, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # flatten(1) reshapes when needed, so non-contiguous inputs (e.g. permuted
        # tensors) work too -- Tensor.view would raise on them.
        x = x.flatten(1)
        return self.model(x)

In [7]:
class Generator(nn.Module):
    """Fully-connected generator mapping latent vectors to images.

    Args:
        latent_dim: width of the input noise vectors (default 100).
        img_size: side length of the square output images (default 112).
        channels: number of output channels (default 3).

    ``forward`` takes a ``(batch, latent_dim)`` tensor and returns a
    ``(batch, channels, img_size, img_size)`` tensor in [-1, 1] (Tanh output).
    NOTE(review): real images should be normalised to [-1, 1] to match --
    confirm against cod.Transform.get_transforms.
    """

    def __init__(self, latent_dim=100, img_size=112, channels=3):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.channels = channels
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            # With the defaults this is 12544 * 3 == 3 * 112 * 112 outputs.
            nn.Linear(1024, channels * img_size * img_size),
            nn.Tanh(),
        )

    def forward(self, x):
        output = self.model(x)
        # The Linear output is freshly allocated (contiguous), so view is safe here.
        return output.view(x.size(0), self.channels, self.img_size, self.img_size)
# Build both networks on the selected device. The discriminator is constructed
# first, then the generator, exactly as before (weight-init RNG order is unchanged).
discriminator, generator = Discriminator().to(device), Generator().to(device)

In [8]:
# Training data pipeline for 112x112 images.
# NOTE(review): get_transforms' return order is assumed to be (val, train) -- confirm.
val_transforms, train_transforms = get_transforms(112)
train_data = MyDataset(os.path.join(dirdata, train_files), dirdata, train_transforms)
train_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    num_workers=0,
    # Pinned host memory only speeds up host->GPU copies; on a CPU-only run it
    # brings nothing and DataLoader emits a warning.
    pin_memory=device.type == "cuda",
    # Every yielded batch has exactly batch_size samples.
    drop_last=True,
    # shuffle must stay False: DataLoader rejects shuffle=True with an explicit sampler.
    shuffle=False,
    sampler=train_data.getSampler(name_clases="multiclass"),
)


In [9]:
# Optimisation hyper-parameters and the BCE criterion, which is applied to the
# discriminator's Sigmoid probabilities.
lr = 1e-4
num_epochs = 50
loss_function = nn.BCELoss()

# One Adam optimiser per network (discriminator first, then generator), same learning rate.
optimizer_discriminator, optimizer_generator = (
    torch.optim.Adam(net.parameters(), lr=lr) for net in (discriminator, generator)
)

In [11]:
for epoch in range(2):
    for real_samples, real_labels in train_dataloader:
        # Данные для тренировки дискриминатора
        real_samples = real_samples.to(device=device)
        real_samples_labels = torch.ones((batch_size, 1)).to(
            device=device)
        latent_space_samples = torch.randn((batch_size, 100)).to(
            device=device)
        generated_samples = generator(latent_space_samples)
        generated_samples_labels = torch.zeros((batch_size, 1)).to(
            device=device)
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels))

        # Обучение дискриминатора
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(
            output_discriminator, all_samples_labels)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Данные для обучения генератора
        latent_space_samples = torch.randn((batch_size, 100)).to(
            device=device)

        # Обучение генератора
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(
            output_discriminator_generated, real_samples_labels)
        loss_generator.backward()
        optimizer_generator.step()

        # Показываем loss
    print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
    print(f"Epoch: {epoch} Loss G.: {loss_generator}")

Epoch: 0 Loss D.: 0.27280429005622864
Epoch: 0 Loss G.: 0.9182949066162109
Epoch: 1 Loss D.: 0.06324335932731628
Epoch: 1 Loss G.: 3.218838691711426
