In [None]:
import torch
from torch import nn
import torchvision
from torchvision import models
from torch.nn import functional as F
from torchvision import transforms as T
from tqdm.notebook import tqdm

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
print(f"using device {device}")

using device cpu


In [None]:
class Generator(nn.Module):
  """MLP that maps a flat input vector to a batch of 28x28 images.

  The input is split column-wise: the first 10 columns feed ``layer_1a``
  (10 -> 200) and the remaining columns feed ``layer_1b`` (200 -> 2000), so the
  input must have 10 + 200 = 210 columns. Presumably the first 10 columns are a
  class condition (e.g. a one-hot digit) and the rest is noise -- assumption,
  confirm against the caller. Both projections go through ReLU, are
  concatenated (2200 features) and passed through a 4-layer MLP that ends in
  Tanh, so every output pixel lies in [-1, 1].
  """

  def __init__(self):
    super().__init__()
    self.layer_1a = nn.Linear(10, 200)
    self.layer_1b = nn.Linear(200, 2000)
    hidden = 200 + 2000
    # Three (Linear, ReLU) pairs followed by a projection to 784 pixels + Tanh;
    # Sequential indices are model.0, model.2, model.4 (hidden) and model.6 (output).
    blocks = []
    for _ in range(3):
      blocks.extend([nn.Linear(hidden, hidden), nn.ReLU()])
    blocks.extend([nn.Linear(hidden, 28 * 28), nn.Tanh()])
    self.model = nn.Sequential(*blocks)

  def forward(self, x):
    """Map ``x`` of shape (N, 210) to images of shape (N, 28, 28)."""
    head = F.relu(self.layer_1a(x[:, :10]))
    tail = F.relu(self.layer_1b(x[:, 10:]))
    flat = self.model(torch.cat((head, tail), dim=1))
    return flat.view(flat.shape[0], 28, 28)

In [None]:
# Instantiate the generator. NOTE(review): it is not moved to `device` here and
# no training cell in this part of the notebook uses it yet.
generator = Generator()

In [None]:
# Use a torchvision ResNet-18 as the discriminator. No `weights` argument is
# passed, so no pretrained weights are requested. The cell displays the stem
# convolution, which expects 3 input channels (see output); the next cell
# replaces it with a 1-channel version for MNIST.
discriminator = models.resnet18()
discriminator.conv1

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [None]:
# MNIST images have a single channel, so swap the 3-channel stem for a
# 1-channel conv with the same geometry (64 filters, 7x7, stride 2, pad 3, no bias).
discriminator.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)

In [None]:
# Inspect the default classification head (1000 outputs) before replacing it.
discriminator.fc

Linear(in_features=512, out_features=1000, bias=True)

In [None]:
# Width of the features entering the head; the replacement head below is sized from it.
discriminator.fc.in_features

512

In [None]:
# Replace the default 1000-way head with a 10-way head (one output per MNIST
# digit class). There is deliberately no Sigmoid: nn.CrossEntropyLoss applies
# log-softmax internally and expects raw, unnormalised logits. Squashing them
# through a Sigmoid first confines every logit to (0, 1), which keeps the
# softmax almost uniform and starves training of gradient signal.
# The Sequential wrapper is kept so parameter names stay `fc.0.weight`/`fc.0.bias`.
discriminator.fc = nn.Sequential(
    nn.Linear(discriminator.fc.in_features, 10),
)

In [None]:
# Adam over every discriminator parameter; beta1 is lowered to 0.5 from
# PyTorch's default of 0.9, beta2 stays at 0.999.
d_optimizer = torch.optim.Adam(
    discriminator.parameters(),
    lr=1e-3,
    betas=(0.5, 0.999),
)

In [None]:
# Multi-class classification loss: takes raw logits plus integer class labels
# and averages the per-sample loss over the batch.
loss_fn = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# Preprocessing applied to every MNIST sample: PIL image -> float tensor
# (ToTensor), then upsampled from 28x28 to 224x224.
# NOTE(review): ToTensor yields values in [0, 1] while the Generator ends in
# Tanh ([-1, 1]); if real and generated images are ever fed to the same
# discriminator, consider adding T.Normalize((0.5,), (0.5,)) to align ranges.
transform = T.Compose(
    [
        T.ToTensor(),
        T.Resize(size=(224, 224)),
    ]
)
# Download (if needed) and wrap both MNIST splits with the same transform.
train_data, test_data = (
    torchvision.datasets.MNIST(root='.', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

In [None]:
# Sanity check: confirm the test split is torchvision's MNIST dataset class.
type(test_data)

torchvision.datasets.mnist.MNIST

In [None]:
# Batches of 100 samples; only the training split is shuffled so evaluation
# order stays fixed.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=shuffle)
    for dataset, shuffle in ((train_data, True), (test_data, False))
)

In [None]:
def train(model, loader, optimizer, loss_fn, device):
  """Run one epoch of supervised training and return the mean batch loss.

  Args:
    model: module to train; it is switched to training mode first.
    loader: sized iterable of ``(data, labels)`` batches (e.g. a DataLoader).
    optimizer: optimizer over ``model``'s parameters; stepped once per batch.
    loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
    device: device each batch is moved to; the model is assumed to already
      live there (it is not moved here).

  Returns:
    The average of the per-batch losses, or NaN when ``loader`` yields no
    batches (an average over zero batches is undefined).
  """
  # Training mode affects layers such as BatchNorm and Dropout.
  model.train()
  running_loss = 0.0
  count = len(loader)
  for data, labels in tqdm(loader, total=count, desc="training"):
    data = data.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    out = model(data)
    loss = loss_fn(out, labels)
    loss.backward()
    # Apply the gradients; without this call the weights never change.
    optimizer.step()
    running_loss += loss.item()

  # Guard the empty-loader case instead of raising ZeroDivisionError.
  avg_loss = running_loss / count if count else float("nan")
  print(f"training loss: {avg_loss}")
  return avg_loss

In [None]:
# Train the discriminator for `epochs` epochs, recording one average loss per
# epoch in `training_losses`.
epochs = 2
training_losses = []
# NOTE(review): the PyTorch optimizer docs recommend moving a model to its
# device *before* constructing its optimizer; d_optimizer was created earlier.
# This works because Module.to() modifies the module in place, but building the
# optimizer after this line would be safer.
discriminator = discriminator.to(device)
for epoch in range(epochs):
  training_loss = train(discriminator, train_loader, d_optimizer, loss_fn, device)
  training_losses.append(training_loss)

training:   0%|          | 0/600 [00:00<?, ?it/s]



In [None]:
# Grab a single batch for inspection. next(iter(...)) raises StopIteration on an
# empty loader instead of silently leaving a stale `batch` from an earlier run.
batch = next(iter(train_loader))

In [None]:
# The loader yields (inputs, targets) pairs, the same order `train` unpacks them in.
data, labels = batch

In [None]:
# Display `data` from the unpacked batch (PyTorch's repr summarises large tensors).
data