# エネルギーベースモデル

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
from IPython.display import display

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('mps')
device

device(type='mps')

In [2]:
batch_size = 64

dataset = MNIST(
    root="data/",
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)

sample_x, _ = next(iter(dataloader))
w, h = sample_x.shape[2:]
image_size = w * h
print("batch shape:", sample_x.shape)
print("width:", w)
print("height:", h)
print("image size:", image_size)

batch shape: torch.Size([64, 1, 28, 28])
width: 28
height: 28
image size: 784


In [3]:
class EnergyModel(nn.Module):
    def __init__(self, image_size=image_size):
        super().__init__()
        self.image_size = image_size
        self.net = nn.Sequential(
            nn.Linear(image_size, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        x = x.view(-1, self.image_size)
        return self.net(x)

In [25]:
def generate(model, n_iter=100, n_images=batch_size):
    model.eval()
    z = torch.randn(n_images, image_size, device=device)
    z.requires_grad = True
    optimizer = optim.Adam([z], lr=1e-3)
    for _ in range(n_iter):
        print(_)
        z = F.sigmoid(z)
        optimizer.zero_grad()
        energy = model(z).mean()
        (-energy).backward()
        optimizer.step()
    x = F.sigmoid(z)
    return x.detach()

In [26]:
def draw(model, n_rows=1, n_cols=8, size=64):
    images = generate(model, n_images=n_rows*n_cols)
    images = transforms.Resize(size)(images)
    img = torchvision.utils.make_grid(images, n_cols)
    img = transforms.functional.to_pil_image(img)
    display(img)

In [27]:
def train(model, optim, n_epochs):
    for n in range(n_epochs):
        losses = []
        for (x_real, _) in dataloader:
            x_real = x_real.to(device)
            x_fake = generate(model)

            out_real = model(x_real)
            out_fake = model(x_fake)
            loss = out_fake.mean() - out_real.mean()
            loss.backward()
            losses.append(loss.item())
            optim.step()
        print(f"{n}epoch loss: {sum(losses)/len(losses)}")
        

In [28]:
model = EnergyModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [29]:
draw(model)

0
1


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.