In [1]:
import zipfile
import random
import os
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.utils import save_image
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from torchvision.utils import save_image
import matplotlib.pyplot as plt



In [None]:
!pip install denoising-diffusion-pytorch torchvision


### Load data

In [3]:
import zipfile

zip_path = "minefree-class-split-wo-borders.zip"

extract_to = "data"

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)

print("Extraction complete!")


Extraction complete!


### Diffusion models training


In [15]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import math

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [22]:
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

In [23]:
data_paths = {
    "bombed": "data/minefree-class-split-wo-borders/train/bombed",
    "not_bombed": "data/minefree-class-split-wo-borders/train/not bombed"
}

In [24]:
def load_dataloader(path):
    dataset = datasets.ImageFolder(root=os.path.dirname(path), transform=transform)
    class_name = os.path.basename(path)
    class_idx = dataset.class_to_idx[class_name]
    indices = [i for i, (_, y) in enumerate(dataset.samples) if y == class_idx]
    subset = torch.utils.data.Subset(dataset, indices)
    return DataLoader(subset, batch_size=16, shuffle=True)

In [25]:
bombed_loader = load_dataloader(data_paths["bombed"])
not_bombed_loader = load_dataloader(data_paths["not_bombed"])

In [26]:
def get_timestep_embedding(timesteps, embedding_dim=128):
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, device=timesteps.device) * -emb)
    emb = timesteps[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    return emb

In [27]:
class UNet(nn.Module):
    def __init__(self, time_dim=256):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim)
        )

        self.encoder1 = nn.Conv2d(3, 64, 3, padding=1)
        self.down1 = nn.Conv2d(64, 128, 3, padding=1)
        self.encoder2 = nn.Conv2d(128, 256, 3, padding=1)
        self.down2 = nn.Conv2d(256, 256, 3, padding=1)

        self.time_embed1 = nn.Linear(time_dim, 128)
        self.time_embed2 = nn.Linear(time_dim, 256)

        self.up1 = nn.ConvTranspose2d(256, 128, 3, padding=1)
        self.decoder1 = nn.Conv2d(128, 64, 3, padding=1)
        self.out = nn.Conv2d(64, 3, 1)

    def forward(self, x, t):
        t = t.float().unsqueeze(-1) / 1000
        t_embed = self.time_mlp(t)

        e1 = F.relu(self.encoder1(x))
        d1 = F.relu(self.down1(e1))

        t1 = self.time_embed1(t_embed).unsqueeze(-1).unsqueeze(-1)
        e2 = F.relu(self.encoder2(d1 + t1))

        d2 = F.relu(self.down2(e2))
        t2 = self.time_embed2(t_embed).unsqueeze(-1).unsqueeze(-1)
        d2 = d2 + t2

        u1 = F.relu(self.up1(d2))
        d1 = F.interpolate(d1, size=u1.shape[2:])
        u1 = u1 + d1

        out = F.relu(self.decoder1(u1))
        return self.out(out)

In [35]:
T = 1_000
betas = torch.linspace(1e-4, 0.02, T).to(device)
alphas = 1. - betas
alpha_bars = torch.cumprod(alphas, dim=0)

In [37]:
def train(model, dataloader, epochs=500, label="bombed"):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(epochs):
        for images, _ in dataloader:
            images = images.to(device)
            optimizer.zero_grad()

            t = torch.randint(0, T, (images.size(0),), device=device)
            alpha_bar_t = alpha_bars[t].view(-1, 1, 1, 1).to(device)

            noise = torch.randn_like(images)
            noisy_images = torch.sqrt(alpha_bar_t) * images + torch.sqrt(1 - alpha_bar_t) * noise

            predicted_noise = model(noisy_images, t)
            loss = F.mse_loss(predicted_noise, noise)
            loss.backward()
            optimizer.step()

        print(f"[{label}] Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

In [38]:
@torch.no_grad()
def sample(model, img_size=(3, 64, 64), steps=T):
    model.eval()
    x = torch.randn((1, *img_size)).to(device)
    for t in reversed(range(steps)):
        t_batch = torch.full((1,), t, device=device, dtype=torch.long)
        beta_t = betas[t]
        alpha_t = alphas[t]
        alpha_bar_t = alpha_bars[t]

        predicted_noise = model(x, t_batch)
        noise = torch.randn_like(x) if t > 0 else 0

        x = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)) * predicted_noise) + torch.sqrt(beta_t) * noise
    return x

In [39]:
bombed_model = UNet()
not_bombed_model = UNet()

In [40]:
train(bombed_model, bombed_loader, label="bombed")
train(not_bombed_model, not_bombed_loader, label="not_bombed")


[bombed] Epoch 1/500, Loss: 0.6294
[bombed] Epoch 2/500, Loss: 0.2431
[bombed] Epoch 3/500, Loss: 0.4169
[bombed] Epoch 4/500, Loss: 0.4090
[bombed] Epoch 5/500, Loss: 0.0572
[bombed] Epoch 6/500, Loss: 0.0568
[bombed] Epoch 7/500, Loss: 0.0401
[bombed] Epoch 8/500, Loss: 0.0375
[bombed] Epoch 9/500, Loss: 0.0715
[bombed] Epoch 10/500, Loss: 0.3121
[bombed] Epoch 11/500, Loss: 0.0354
[bombed] Epoch 12/500, Loss: 0.0303
[bombed] Epoch 13/500, Loss: 0.0285
[bombed] Epoch 14/500, Loss: 0.0261
[bombed] Epoch 15/500, Loss: 0.0361
[bombed] Epoch 16/500, Loss: 0.0249
[bombed] Epoch 17/500, Loss: 0.0236
[bombed] Epoch 18/500, Loss: 0.0263
[bombed] Epoch 19/500, Loss: 0.0228
[bombed] Epoch 20/500, Loss: 0.0267
[bombed] Epoch 21/500, Loss: 0.0204
[bombed] Epoch 22/500, Loss: 0.0219
[bombed] Epoch 23/500, Loss: 0.0179
[bombed] Epoch 24/500, Loss: 0.4440
[bombed] Epoch 25/500, Loss: 0.0527
[bombed] Epoch 26/500, Loss: 0.0173
[bombed] Epoch 27/500, Loss: 0.0168
[bombed] Epoch 28/500, Loss: 0.0143
[

In [41]:
from torchvision.utils import save_image

os.makedirs("samples_bombed", exist_ok=True)

os.makedirs("samples_not_bombed", exist_ok=True)

for i in range(10):
    img = sample(bombed_model)
    save_image(img, f"samples_bombed/sample_{i+1}.png")

for i in range(10):
    img = sample(not_bombed_model)
    save_image(img, f"samples_not_bombed/sample_{i+1}.png")

print("done!")


done!
