In [1]:
import numpy as np
import random as random

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# --- Experiment configuration ------------------------------------------------
img_shape = (1, 28, 28)  # (channels, height, width) of each model input
batch_size = 64
# NOTE(review): the datasets below use CelebA with target_type="attr", which yields a
# multi-attribute target vector rather than a single class index — confirm 10 is intended.
num_classes = 10
lr = 1e-4
beta1 = 0.0  # presumably the Adam beta1 for a later optimizer — TODO confirm
alpha = 0.1  # TODO(review): role not visible in these cells (loss weight?) — document it
epochs = 10
seed = 0

# Seed every RNG the notebook relies on (data shuffling, weight init, sampling) so that
# Restart & Run All reproduces the same results.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [3]:
# Root of the local CelebA copy (download=False, so the files must already exist here).
# TODO: make this configurable (e.g. an environment variable) instead of an absolute path.
celeba_root = "/work/home/maben/project/blue_whale_lab/projects/pareto_ebm/project/datasets/"

# The model consumes tensors of shape `img_shape`, but raw CelebA images are RGB and far
# larger than 28x28. Convert channels and size here so batches match the configuration.
img_channels, img_height, img_width = img_shape
transform_steps = []
if img_channels == 1:
    transform_steps.append(transforms.Grayscale(num_output_channels=1))
transform_steps += [
    # Resize the shorter edge, then center-crop, so the aspect ratio is not distorted.
    transforms.Resize(min(img_height, img_width)),
    transforms.CenterCrop((img_height, img_width)),
    transforms.ToTensor(),
    # ToTensor yields values in [0, 1]; map every channel to [-1, 1].
    transforms.Normalize((0.5,) * img_channels, (0.5,) * img_channels),
]
transform = transforms.Compose(transform_steps)


def make_celeba(split):
    """Build the CelebA `split` ("train", "valid" or "test") with attribute targets."""
    return datasets.CelebA(
        root=celeba_root,
        split=split,
        transform=transform,
        target_type="attr",
        download=False,
    )


train_dataset = make_celeba("train")
val_dataset = make_celeba("valid")
test_dataset = make_celeba("test")

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
class Swish(nn.Module):
    """Swish activation: scales the input elementwise by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

In [5]:
class JEMClassifier(nn.Module):
    """CNN with a shared convolutional trunk and two linear heads: energy and class logits.

    Args:
        hidden_features: Base channel width; the conv stack uses hidden_features // 2,
            hidden_features and 2 * hidden_features channels.
        input_shape: (channels, height, width) of one input image. The default
            (1, 28, 28) reproduces the original hard-coded architecture exactly; other
            shapes are supported by sizing the heads from a dry run of the trunk.
        n_classes: Width of the classification head. ``None`` falls back to the
            notebook-level ``num_classes`` setting (the original behavior).
    """

    def __init__(self, hidden_features=32, input_shape=(1, 28, 28), n_classes=None):
        super().__init__()
        if n_classes is None:
            n_classes = num_classes
        in_channels = input_shape[0]
        c_hid1 = hidden_features // 2
        c_hid2 = hidden_features
        c_hid3 = hidden_features * 2

        # Each stride-2 conv roughly halves the spatial size; the trailing comments give
        # the resulting size for the default 28x28 input.
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(in_channels, c_hid1, kernel_size=5, stride=2, padding=4),  # [16x16]
            Swish(),
            nn.Conv2d(c_hid1, c_hid2, kernel_size=3, stride=2, padding=1),  # [8x8]
            Swish(),
            nn.Conv2d(c_hid2, c_hid3, kernel_size=3, stride=2, padding=1),  # [4x4]
            Swish(),
            nn.Conv2d(c_hid3, c_hid3, kernel_size=3, stride=2, padding=1),  # [2x2]
            Swish(),
            nn.Flatten(),
        )

        # Size the heads from the actual trunk output instead of assuming a 2x2 map, so
        # non-default input shapes work. This dry run consumes no RNG, so weight
        # initialization is unchanged for a given seed.
        with torch.no_grad():
            n_features = self.cnn_layers(torch.zeros(1, *input_shape)).shape[-1]
        self.fc_energy = nn.Linear(n_features, 1)
        self.fc_class = nn.Linear(n_features, n_classes)

    def forward(self, x):
        """Return ``(energy, logits)`` for a batch ``x`` of shape (batch, *input_shape).

        ``energy`` has shape (batch,) and ``logits`` has shape (batch, n_classes).
        """
        features = self.cnn_layers(x)
        energy = self.fc_energy(features).squeeze(-1)
        logits = self.fc_class(features)
        return energy, logits