In [1]:
# prepare data
# define model
# define training and validation
# perform experiment

# joint energy based model的两大功能：
# 1.discriminative ability：来自softmax
# 2.generation ability：来自对logit的reinterpret以及对其的约束
# performance gap 本质上是对logit进行正则化带来的performance gap

In [2]:
import numpy as np
import random as random

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
img_shape = (1, 28, 28)
batch_size = 64
num_classes = 10
lr = 1e-4
beta1 = 0.0
alpha = 0.1 
epochs = 10

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root="/work/home/maben/project/blue_whale_lab/projects/pareto_ebm/datasets", train=True, transform=transform, download=True)
val_dataset = datasets.MNIST(root="/work/home/maben/project/blue_whale_lab/projects/pareto_ebm/datasets", train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


In [5]:
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

class JEMClassifier(nn.Module):
    def __init__(self, hidden_features=32):
        super().__init__()
        c_hid1 = hidden_features // 2
        c_hid2 = hidden_features
        c_hid3 = hidden_features * 2

        self.cnn_layers = nn.Sequential(
            nn.Conv2d(1, c_hid1, kernel_size=5, stride=2, padding=4),  # [16x16]
            Swish(),
            nn.Conv2d(c_hid1, c_hid2, kernel_size=3, stride=2, padding=1),  # [8x8]
            Swish(),
            nn.Conv2d(c_hid2, c_hid3, kernel_size=3, stride=2, padding=1),  # [4x4]
            Swish(),
            nn.Conv2d(c_hid3, c_hid3, kernel_size=3, stride=2, padding=1),  # [2x2]
            Swish(),
            nn.Flatten()
        )
        self.fc_class = nn.Linear(c_hid3 * 4, num_classes)

    def forward(self, x):
        features = self.cnn_layers(x)
        logits = self.fc_class(features)
        energy = torch.log(torch.sum(torch.exp(logits),dim=-1))
        return energy, logits

In [16]:
class Sampler:
    def __init__(self, model, img_shape, sample_size, max_len=8192):
        super().__init__()
        self.model = model
        self.img_shape = img_shape
        self.sample_size = sample_size
        self.max_len = max_len
        self.examples = [(torch.rand((1,)+img_shape)*2-1) for _ in range(self.sample_size)]

    def sample_new_exmps(self, steps=60, step_size=10):
        device = next(self.model.parameters()).device 

        n_new = np.random.binomial(self.sample_size, 0.05)
        rand_imgs = torch.rand((n_new,) + self.img_shape, device=device) * 2 - 1
        old_imgs = torch.cat(random.choices(self.examples, k=self.sample_size - n_new), dim=0).to(device)
        inp_imgs = torch.cat([rand_imgs, old_imgs], dim=0).detach()

        inp_imgs, energy_list = self.generate_samples(self.model, inp_imgs, steps=steps, step_size=step_size)

        self.examples = list(inp_imgs.cpu().chunk(self.sample_size, dim=0)) + self.examples
        self.examples = self.examples[:self.max_len]
        return inp_imgs, energy_list

    def conditional_sample(self, target_class, steps=60, step_size=10, num_samples=100):
        """条件采样：生成指定类别的样本（基于p(y|x)最大值）"""
        device = next(self.model.parameters()).device
        
        # 生成大量候选样本
        candidate_imgs = torch.rand((num_samples,) + self.img_shape, device=device) * 2 - 1
        
        # 条件采样
        candidate_imgs = self.generate_conditional_samples(
            self.model, candidate_imgs, target_class, steps=steps, step_size=step_size
        )
        
        # 计算所有候选样本的p(y|x)值
        with torch.no_grad():
            _, logits = self.model(candidate_imgs)
            # 计算p(y|x) = softmax(logits)
            p_y_given_x = F.softmax(logits, dim=1)
            # 获取目标类别的概率
            target_probs = p_y_given_x[:, target_class]
        
        # 选择p(y|x)值最高的前10%样本
        num_top_samples = max(1, int(0.1 * num_samples))  # 至少保留1个样本
        _, top_indices = torch.topk(target_probs, num_top_samples)
        selected_imgs = candidate_imgs[top_indices]
        
        return selected_imgs

    def generate_samples(self,model, inp_imgs, steps=60, step_size=10, return_img_per_step=False):
        device = inp_imgs.device
        is_training = model.training
        model.eval()

        for p in model.parameters():
            p.requires_grad = False
        inp_imgs.requires_grad = True

        had_gradients_enabled = torch.is_grad_enabled()
        torch.set_grad_enabled(True)

        noise = torch.randn(inp_imgs.shape, device=device) 

        imgs_per_step = []
        energy_list = []

        for _ in range(steps):
            noise.normal_(0, 0.005)
            inp_imgs.data.add_(noise.data)
            inp_imgs.data.clamp_(min=-1.0, max=1.0)

            out_imgs = -model(inp_imgs)[0]
            out_imgs.sum().backward()
            energy_list.append(out_imgs.detach().to('cpu').numpy())
            inp_imgs.grad.data.clamp_(-0.03, 0.03)

            inp_imgs.data.add_(-step_size * inp_imgs.grad.data)
            inp_imgs.grad.detach_()
            inp_imgs.grad.zero_()
            inp_imgs.data.clamp_(min=-1.0, max=1.0)

            if return_img_per_step:
                imgs_per_step.append(inp_imgs.clone().detach())

        for p in model.parameters():
            p.requires_grad = True
        model.train(is_training)
        torch.set_grad_enabled(had_gradients_enabled)

        if return_img_per_step:
            return torch.stack(imgs_per_step, dim=0), energy_list
        else:
            return inp_imgs, energy_list

    def generate_conditional_samples(self, model, inp_imgs, target_class, steps=60, step_size=10, return_img_per_step=False):
        """条件采样：生成指定类别的样本（基于p(y|x)最大值）"""
        device = inp_imgs.device
        is_training = model.training
        model.eval()

        for p in model.parameters():
            p.requires_grad = False
        inp_imgs.requires_grad = True

        had_gradients_enabled = torch.is_grad_enabled()
        torch.set_grad_enabled(True)

        noise = torch.randn(inp_imgs.shape, device=device) 

        imgs_per_step = []
        energy_list = []
        for _ in range(steps):
            noise.normal_(0, 0.005)
            inp_imgs.data.add_(noise.data)
            inp_imgs.data.clamp_(min=-1.0, max=1.0)

            # 获取能量和logits
            energy, logits = model(inp_imgs)
            energy_list.append(energy.detach().to('cpu').numpy())
            # 计算p(y|x) = softmax(logits)
            p_y_given_x = F.softmax(logits, dim=1)
            
            # 条件能量：使用目标类别的p(y|x)值作为能量
            # 我们要最大化p(y|x)，所以使用负号（因为我们要最小化能量）
            conditional_energy = -p_y_given_x[:, target_class]
            
            # 反向传播
            conditional_energy.sum().backward()
            inp_imgs.grad.data.clamp_(-0.03, 0.03)

            inp_imgs.data.add_(-step_size * inp_imgs.grad.data)
            inp_imgs.grad.detach_()
            inp_imgs.grad.zero_()
            inp_imgs.data.clamp_(min=-1.0, max=1.0)

            if return_img_per_step:
                imgs_per_step.append(inp_imgs.clone().detach())

        for p in model.parameters():
            p.requires_grad = True
        model.train(is_training)
        torch.set_grad_enabled(had_gradients_enabled)

        if return_img_per_step:
            return torch.stack(imgs_per_step, dim=0), energy_list
        else:
            return inp_imgs, energy_list

In [7]:
def jem_loss(energy_real, energy_fake, logits, labels, alpha):
    min_batch_size = min(energy_real.size(0), energy_fake.size(0))
    energy_real = energy_real[:min_batch_size]
    energy_fake = energy_fake[:min_batch_size]

    reg_loss = alpha * (energy_real ** 2 + energy_fake ** 2).mean()
    cdiv_loss = energy_fake.mean() - energy_real.mean()

    class_loss = nn.CrossEntropyLoss()(logits, labels)

    total_loss = reg_loss + cdiv_loss + class_loss
    return total_loss, reg_loss, cdiv_loss, class_loss

In [8]:
model = JEMClassifier(hidden_features=32).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(beta1, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.97)
sampler = Sampler(model, img_shape=(1, 28, 28), sample_size=batch_size)

In [9]:
for epoch in range(epochs):
    model.train()
    train_loss, reg_loss, cdiv_loss, class_loss = 0, 0, 0, 0

    for real_imgs, labels in train_loader:
        real_imgs, labels = real_imgs.to(device), labels.to(device)

        small_noise = torch.randn_like(real_imgs) * 0.005
        real_imgs = real_imgs + small_noise
        real_imgs = real_imgs.clamp(min=-1.0, max=1.0)

        fake_imgs = sampler.sample_new_exmps(steps=60, step_size=10)

        energy_real, logits_real = model(real_imgs)
        energy_fake, _ = model(fake_imgs)

        loss, r_loss, c_loss, cls_loss = jem_loss(
            energy_real, energy_fake, logits_real, labels, alpha
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        reg_loss += r_loss.item()
        cdiv_loss += c_loss.item()
        class_loss += cls_loss.item()

    scheduler.step()

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for val_imgs, val_labels in val_loader:
            val_imgs, val_labels = val_imgs.to(device), val_labels.to(device)
            _, logits_val = model(val_imgs)
            predictions = torch.argmax(logits_val, dim=1)
            correct += (predictions == val_labels).sum().item()
            total += val_labels.size(0)

    accuracy = correct / total
    print(f"Epoch {epoch + 1}/{epochs}")
    print(f"Train Loss: {train_loss / len(train_loader):.4f}, Reg Loss: {reg_loss / len(train_loader):.4f}, "
          f"CDiv Loss: {cdiv_loss / len(train_loader):.4f}, Class Loss: {class_loss / len(train_loader):.4f}")
    print(f"Validation Accuracy: {accuracy:.4f}")

Epoch 1/10
Train Loss: 1.9587, Reg Loss: 0.1071, CDiv Loss: 0.1156, Class Loss: 1.7360
Validation Accuracy: 0.7574
Epoch 2/10
Train Loss: 1.0274, Reg Loss: 0.0176, CDiv Loss: 0.1917, Class Loss: 0.8181
Validation Accuracy: 0.8380
Epoch 3/10
Train Loss: 0.7902, Reg Loss: 0.0144, CDiv Loss: 0.1871, Class Loss: 0.5887
Validation Accuracy: 0.8814
Epoch 4/10
Train Loss: 0.6679, Reg Loss: 0.0131, CDiv Loss: 0.1771, Class Loss: 0.4777
Validation Accuracy: 0.9005
Epoch 5/10
Train Loss: 0.5865, Reg Loss: 0.0127, CDiv Loss: 0.1656, Class Loss: 0.4083
Validation Accuracy: 0.9093
Epoch 6/10
Train Loss: 0.5235, Reg Loss: 0.0123, CDiv Loss: 0.1546, Class Loss: 0.3565
Validation Accuracy: 0.9237
Epoch 7/10
Train Loss: 0.4720, Reg Loss: 0.0120, CDiv Loss: 0.1458, Class Loss: 0.3142
Validation Accuracy: 0.9333
Epoch 8/10
Train Loss: 0.4287, Reg Loss: 0.0117, CDiv Loss: 0.1359, Class Loss: 0.2810
Validation Accuracy: 0.9366
Epoch 9/10
Train Loss: 0.3948, Reg Loss: 0.0107, CDiv Loss: 0.1283, Class Loss: 

In [40]:
def generate_images_with_sampler(model, sampler, num_images=16, steps=60, step_size=10):
    model.eval()
    with torch.no_grad():
        generated_imgs, energy_list = sampler.sample_new_exmps(steps=steps, step_size=step_size)
        generated_imgs = generated_imgs[:num_images].detach() 
        print(energy_list)
        energy_list = energy_list.squeeze()[:num_images]
        print(energy_list)
    return generated_imgs.cpu(), energy_list

def generate_conditional_images_with_sampler(model, sampler, target_class, num_images=16, steps=60, step_size=10):
    model.eval()
    with torch.no_grad():
        generated_imgs, energy_list = sampler.conditional_sample(target_class=target_class, steps=steps, step_size=step_size, num_samples=num_images)
        generated_imgs = generated_imgs[:num_images].detach() 
        energy_list = energy_list.squeeze()[:num_images]
    return generated_imgs.cpu(), energy_list

def plot_generated_images(images,energy_list, num_cols=4):
    num_images = images.size(0)
    num_rows = (num_images + num_cols - 1) // num_cols

    images = (images + 1) / 2 

    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))
    axes = axes.flatten()
    print(energy_list[0])
    for i, ax in enumerate(axes):
        if i < num_images:
            img = images[i].squeeze(0).numpy()
            #ax.set_title(f"energy: {energy_list[i]:.2f}")
            ax.imshow(img, cmap="gray")
            ax.axis("off")
        else:
            ax.axis("off")
    plt.tight_layout()
    plt.show()


In [41]:
num_images = 16
generated_images,energy_list = generate_images_with_sampler(model, sampler, num_images=num_images, steps=2000, step_size=0.1)

plot_generated_images(generated_images,energy_list)

[array([5.565112  , 5.4121475 , 5.6369705 , 5.701155  , 1.3122103 ,
       0.60476726, 0.28766826, 0.2598699 , 0.37230647, 0.3600864 ,
       0.6271051 , 0.45762604, 0.20580508, 0.11808214, 5.598494  ,
       0.7900234 , 5.607905  , 5.5186286 , 0.1015663 , 0.3807166 ,
       0.1903567 , 0.3258131 , 0.38311183, 1.1647314 , 1.1502368 ,
       1.0379624 , 0.26200175, 0.13962284, 0.5833627 , 0.02560124,
       0.231122  , 0.32462534, 0.498031  , 1.2037928 , 0.21093306,
       1.0994397 , 0.07756635, 0.21070403, 0.41354343, 0.4132673 ,
       0.22561759, 5.6425633 , 5.6507425 , 1.239773  , 1.4802846 ,
       0.5135548 , 1.2503045 , 0.21956792, 5.6574473 , 1.3184294 ,
       5.521219  , 1.1515639 , 0.10610011, 5.619008  , 0.9171724 ,
       1.1710314 , 1.248094  , 1.2439238 , 1.4640326 , 5.690794  ,
       0.22976536, 0.31273893, 0.5146125 , 0.23208325], dtype=float32), array([5.5612717 , 5.40656   , 5.6347146 , 5.697483  , 1.3105339 ,
       0.6058638 , 0.2871025 , 0.2594149 , 0.37066364, 0

AttributeError: 'list' object has no attribute 'squeeze'

In [31]:
num_images = 16
generated_images,energy_list = generate_conditional_images_with_sampler(model, sampler, target_class=0, num_images=num_images, steps=2000, step_size=0.1)

plot_generated_images(generated_images,energy_list)

TypeError: conv2d() received an invalid combination of arguments - got (tuple, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, tuple of ints padding = 0, tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!tuple of (Tensor, list)!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, str padding = "valid", tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!tuple of (Tensor, list)!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)
