In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random

In [None]:
from lstm import LSTM
from ddpm import DDPM, ContextUnet

In [None]:
def create_sine_dataset(samples=100, seq_len=50):
    """
    Build a batch of noisy sine waves for sequence-model experiments.

    Every sequence is sampled on the same grid of ``seq_len`` points spanning
    [0, 10]. Each one gets its own random frequency in [0.5, 2.0) and phase in
    [0, 2*pi), plus Gaussian noise with standard deviation 0.05.

    Args:
        samples: number of sequences to generate.
        seq_len: number of time points per sequence.

    Returns:
        FloatTensor of shape [samples, seq_len, 1].
    """
    time_axis = np.linspace(0, 10, seq_len)  # identical for every sequence

    def _noisy_sine():
        # Draw order (frequency, phase, noise) is kept fixed so a seeded
        # np.random produces the same dataset.
        frequency = np.random.uniform(0.5, 2.0)
        offset = np.random.uniform(0, 2 * np.pi)
        jitter = 0.05 * np.random.randn(seq_len)
        return np.sin(frequency * time_axis + offset) + jitter

    stacked = np.array([_noisy_sine() for _ in range(samples)])
    # Trailing singleton axis is the per-step feature dimension.
    return torch.FloatTensor(stacked).unsqueeze(2)

def create_shapes_dataset(samples=100, size=32):
    """
    Build a batch of binary images, each with one filled axis-aligned square.

    The square's top-left corner is drawn uniformly from [5, size-10) on both
    axes, and its side length from [5, 10). Pixels inside the square are 1.0;
    all others are 0.0.

    Args:
        samples: number of images.
        size: height and width of each (square) image.

    Returns:
        FloatTensor of shape [samples, 1, size, size].
    """
    def _draw_square():
        canvas = np.zeros((size, size))
        # Tuple RHS evaluates left to right: row, then column (same RNG order as before).
        top, left = np.random.randint(5, size - 10), np.random.randint(5, size - 10)
        side = np.random.randint(5, 10)
        canvas[top:top + side, left:left + side] = 1.0
        return canvas

    images = np.array([_draw_square() for _ in range(samples)])
    # Insert the channel axis expected by conv nets.
    return torch.FloatTensor(images).unsqueeze(1)

In [None]:
# 1000 noisy sine sequences of 50 steps each; display the resulting tensor shape.
samples = create_sine_dataset(samples=1000, seq_len=50)
samples.shape

In [None]:
# 100 single-square 32x32 images; display the resulting tensor shape.
samples_shape = create_shapes_dataset(samples=100, size=32)
samples_shape.shape

In [None]:
# Overlay four randomly chosen sine sequences. The index is drawn from the whole
# dataset: randrange excludes len(samples), unlike the inclusive random.randint.
fig, ax = plt.subplots()
for _ in range(4):
    idx = random.randrange(len(samples))
    ax.plot(samples[idx, :, 0].numpy())
ax.set(title="Random training sequences", xlabel="time step", ylabel="value");

In [None]:
def train_lstm(model, data, epochs=50, lr=0.01):
    """
    Fit ``model`` to predict the next time step of ``data`` (full-batch Adam + MSE).

    Every epoch performs one forward/backward pass over the entire dataset. The
    model output at step t is scored against the data at step t+1.

    Args:
        model: module whose forward returns ``(predictions, extra)``. ``extra`` is
            ignored here, and ``predictions[:, :-1, :]`` is compared with
            ``data[:, 1:, :]``.
        data: tensor of shape [Batch, Seq_Len, Input_Size] with Seq_Len >= 2.
        epochs: number of optimisation steps (one per epoch).
        lr: Adam learning rate; defaults to the previously hard-coded 0.01.

    Returns:
        List of per-epoch training losses as Python floats.

    Raises:
        ValueError: if ``data`` has fewer than 2 time steps. There would be no
            (input, next-step target) pair and the MSE would be NaN.
    """
    if data.size(1) < 2:
        raise ValueError(f"train_lstm needs at least 2 time steps, got {data.size(1)}")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    print("--- Training LSTM ---")
    model.train()

    losses = []
    for epoch in range(epochs):
        optimizer.zero_grad()

        predictions, _ = model(data)

        # Input at t should predict data at t+1: drop the last prediction
        # (it has no target) and the first data point (it has no predictor).
        preds_shifted = predictions[:, :-1, :]
        targets_shifted = data[:, 1:, :]

        loss = criterion(preds_shifted, targets_shifted)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss {loss.item():.5f}")

    return losses

In [None]:
def generate_sequence(model, seed_data, future_steps=50):
    """
    Roll a trained LSTM forward autoregressively from a seed sequence.

    The seed is fed step by step through ``model.cell`` to warm up the hidden
    and cell states. After that, each output of ``model.predictor`` is fed back
    in as the next input.

    Args:
        model: trained LSTM exposing ``hidden_size``, ``cell(x, h, c) -> (h, c, _)``
            and ``predictor(h)``.
        seed_data: tensor of shape [Batch, Seq_Len, Input_Size]. Batch may be > 1;
            each sequence is continued independently.
        future_steps: number of steps to generate. Values <= 0 yield an empty
            time axis.

    Returns:
        generated_seq: tensor of shape [Batch, Future_Steps, Input_Size]. The
        predictor's output width must equal Input_Size, because predictions are
        fed back as inputs.

    The model's train/eval mode is restored before returning.
    """
    batch_size, seq_len, input_size = seed_data.shape
    if future_steps <= 0:
        return seed_data.new_zeros((batch_size, 0, input_size))

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Zero initial state, matching the seed's batch size, dtype and device.
            h = seed_data.new_zeros(batch_size, model.hidden_size)
            c = seed_data.new_zeros(batch_size, model.hidden_size)

            # 1. Warm up (h, c) on the seed; no outputs are collected here.
            for t in range(seq_len):
                h, c, _ = model.cell(seed_data[:, t, :], h, c)

            # The first prediction comes from the state after the last seed step.
            current_input = model.predictor(h)
            generated_values = [current_input]

            # 2. Autoregressive loop: the last prediction becomes the next input.
            for _ in range(future_steps - 1):
                h, c, _ = model.cell(current_input, h, c)
                current_input = model.predictor(h)
                generated_values.append(current_input)
    finally:
        # Don't leak eval mode into later training code.
        model.train(was_training)

    return torch.stack(generated_values, dim=1)

In [None]:
# Toy data: 1000 noisy sine waves, 50 steps each.
lstm_data = create_sine_dataset(samples=1000, seq_len=50)

# Fit a 16-unit LSTM on next-step prediction; `loss` holds the per-epoch losses.
lstm_model = LSTM(input_size=1, hidden_size=16)
loss = train_lstm(lstm_model, lstm_data, epochs=200)

# Seed the generator with the first training sequence, keeping the batch axis: [1, 50, 1].
seed_sample = lstm_data[:1]
future_steps = 50
generated = generate_sequence(lstm_model, seed_sample, future_steps=future_steps)

# Training curve.
plt.plot(loss)
plt.ylabel("loss")
plt.xlabel("epoch")

In [None]:
# Sanity-check the shapes and values produced by the generation cell.
seed_shape, generated_shape = seed_sample.shape, generated.shape
print(f"Seed shape: {seed_shape}")
print(f"Generated shape: {generated_shape}")
first_generated = generated[0, :, 0].numpy()
print(f"Generated values: {first_generated}")

In [None]:
# generate_sequence draws no random numbers itself and runs the model in eval mode.
# Assuming the LSTM cell is deterministic, repeated calls on one seed return the
# same curve, so generate once and draw the continuation *after* the seed on the
# time axis (not on top of it).
future_steps = 50
generated_ = generate_sequence(lstm_model, seed_sample, future_steps=future_steps)
seed_len = seed_sample.size(1)

fig, ax = plt.subplots()
ax.plot(range(seed_len), seed_sample[0, :, 0].numpy(), label="seed")
ax.plot(range(seed_len, seed_len + future_steps), generated_[0, :, 0].numpy(), label="generated")
ax.set(title="LSTM autoregressive continuation", xlabel="time step", ylabel="value")
ax.legend();

In [None]:
# Full-dataset forward pass for inspection only: eval mode, and no autograd graph.
lstm_model.eval()
with torch.no_grad():
    lstm_out, history = lstm_model(lstm_data)

# Overlay the model's outputs for four random sequences drawn from the whole dataset.
fig, ax = plt.subplots()
for _ in range(4):
    idx = random.randrange(len(lstm_out))
    ax.plot(lstm_out[idx, :, 0].numpy())
ax.set(title="LSTM outputs on random training sequences", xlabel="time step", ylabel="prediction");

In [None]:
history["forget"][0].shape

In [None]:
# NOTE(review): the layout of history["forget"] is defined inside lstm.LSTM. This
# presumably picks one time step's gate values for the first sequence — confirm.
first_forget = history["forget"][0][0].detach().numpy()
plt.plot(first_forget)

In [None]:
def train_ddpm(ddpm_model, data, epochs=50, lr=1e-3):
    """
    Train ``ddpm_model.network`` with the DDPM noise-prediction objective.

    The whole dataset is used as one batch per epoch. Each step:
      1. Draws a timestep t per sample.
      2. Forms x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps.
      3. Regresses the network's output on eps with MSE.

    NOTE(review): this helper reads ``ddpm_model.network``, ``.n_steps`` and
    ``.alpha_bars``. The DDPM built in ``train_ddpm_on_mnist`` is constructed from
    a schedule dict instead, so confirm the DDPM class exposes these names.

    Args:
        ddpm_model: module exposing ``network(x_t, t)``, ``n_steps`` and a 1-D
            ``alpha_bars`` tensor indexable by timestep.
        data: tensor of clean samples, batch first; any trailing rank is allowed.
        epochs: number of optimisation steps.
        lr: Adam learning rate; defaults to the previously hard-coded 1e-3.

    Returns:
        List of per-epoch losses as Python floats.
    """
    optimizer = optim.Adam(ddpm_model.network.parameters(), lr=lr)
    criterion = nn.MSELoss()

    print("\n--- Training DDPM ---")
    ddpm_model.train()

    x0 = data
    n = len(x0)
    # Keep the schedule on the data's device, so indexing with t and the arithmetic below agree.
    alpha_bars = ddpm_model.alpha_bars.to(x0.device)
    # Per-sample alpha_bar reshaped to [n, 1, ..., 1], so it broadcasts over any data rank.
    broadcast_shape = (-1,) + (1,) * (x0.dim() - 1)

    losses = []
    for epoch in range(epochs):
        optimizer.zero_grad()

        # 1. Random timestep per sample, created on the same device as the data/network.
        t = torch.randint(0, ddpm_model.n_steps, (n,), device=x0.device)

        # 2. The Gaussian noise the network must predict.
        epsilon = torch.randn_like(x0)

        # 3. Forward diffusion: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
        a_bar = alpha_bars[t].view(broadcast_shape)
        noisy_image = torch.sqrt(a_bar) * x0 + torch.sqrt(1 - a_bar) * epsilon

        # 4-5. Predict the noise and score it against the true noise.
        noise_pred = ddpm_model.network(noisy_image, t)
        loss = criterion(noise_pred, epsilon)

        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss {loss.item():.5f}")

    return losses

In [None]:
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid

def _linear_ddpm_schedules(n_T, device, beta_1=1e-4, beta_T=0.02):
    """Precompute DDPM constants for a linear beta schedule of n_T + 1 entries."""
    betas = torch.linspace(beta_1, beta_T, n_T + 1).to(device)
    alphas = 1 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    # Precomputed once so the training/sampling loops don't recompute them per step.
    return {
        "sqrtab": torch.sqrt(alphas_bar),
        "sqrtmab": torch.sqrt(1 - alphas_bar),
        "oneover_sqrta": 1 / torch.sqrt(alphas),
        "mab_over_sqrtmab": (1 - alphas) / torch.sqrt(1 - alphas_bar),
        "sqrt_beta_t": torch.sqrt(betas),
    }


def _mnist_loader(batch_size, data_root="./data"):
    """Shuffled DataLoader over MNIST train, pixels mapped from [0, 1] to [-0.5, 0.5]."""
    tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = datasets.MNIST(data_root, train=True, download=True, transform=tf)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)


def _show_samples(ddpm, device, title, n_sample=16):
    """Draw ``n_sample`` images from ``ddpm`` and show them as dark digits on white."""
    ddpm.eval()
    with torch.no_grad():
        x_gen, _ = ddpm.sample(n_sample, (1, 28, 28), device)
    # Undo Normalize(mean=0.5, std=1.0), i.e. x + 0.5 lies in [0, 1], then invert:
    # 1 - (x + 0.5) = 0.5 - x. Clamp because imshow only accepts float RGB in [0, 1].
    pixels = (0.5 - x_gen).clamp(0.0, 1.0)
    grid = make_grid(pixels, nrow=4)
    plt.figure(figsize=(4, 4))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.title(title)
    plt.show()


def train_ddpm_on_mnist(n_epoch=20, batch_size=128, n_T=400, lrate=1e-4, device=None):
    """
    Train a ContextUnet-based DDPM on MNIST and return the trained DDPM.

    Args:
        n_epoch: passes over the MNIST training set.
        batch_size: DataLoader batch size.
        n_T: number of diffusion timesteps.
        lrate: Adam learning rate.
        device: torch device string; defaults to "cuda" when available, else "cpu".

    Returns:
        The trained DDPM, left on ``device``.

    Side effects: downloads MNIST to ./data on first use and shows a sample grid
    every 5 epochs and after the final epoch.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    dataloader = _mnist_loader(batch_size)
    ddpm_schedules = _linear_ddpm_schedules(n_T, device)

    model = ContextUnet(in_channels=1, n_feat=64).to(device)
    ddpm = DDPM(model, ddpm_schedules, n_T, device)
    # Named `optimizer` so it does not shadow the notebook-level `optim` module alias.
    optimizer = torch.optim.Adam(ddpm.parameters(), lr=lrate)

    print(f"Starts training on {device}...")

    for ep in range(n_epoch):
        ddpm.train()
        loss_ema = None

        for x, _ in dataloader:
            optimizer.zero_grad()
            x = x.to(device)

            # The DDPM forward pass returns (predicted noise, true noise).
            noise_pred, noise = ddpm(x)
            loss = F.mse_loss(noise_pred, noise)
            loss.backward()
            optimizer.step()

            # Exponential moving average smooths the noisy per-batch loss for logging.
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()

        print(f"Epoch {ep:02d} | Loss: {loss_ema:.4f}")

        if ep % 5 == 0 or ep == n_epoch - 1:
            _show_samples(ddpm, device, f"Generated at Epoch {ep}")

    return ddpm

In [None]:
trained_model = train_ddpm_on_mnist()

# Sample on the device training actually used (CUDA when available). A hard-coded
# "cpu" would not match the model's parameters/schedule tensors on a GPU machine.
# `.type` yields the same "cuda"/"cpu" string form the training function passes.
sample_device = next(trained_model.parameters()).device.type
trained_model.eval()
with torch.no_grad():
    # NOTE: this rebinds `history`, replacing the LSTM gate history from earlier cells.
    final_img, history = trained_model.sample(n_sample=1, size=(1, 28, 28), device=sample_device)