In [1]:
import torch, os
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import matplotlib.pyplot as plt
from tqdm import tqdm
from einops import rearrange, repeat

from torchvision.utils import save_image, make_grid
from torchvision import transforms
from torchvision.datasets import MNIST

from modules.networks.unet import UNetModel

# Device selection. CUDA device indices are 0-based, so the value "1" exposes
# the machine's *second* physical GPU to this process (not the first).
# NOTE(review): the variable is set after `import torch` but before the first
# CUDA query on the next line; confirm nothing initialises CUDA earlier.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # expose only physical GPU index 1 (the second GPU)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Training data: MNIST read from a local copy (download=False, so the files
# must already exist under this path).
tf = transforms.Compose([transforms.ToTensor()]) # ToTensor scales pixel values to [0, 1]; no further normalisation is applied
train_dataset = MNIST("/data/edherron/data/MNIST", train=True, download=False, transform=tf)
# val_dataset = MNIST("/data/edherron/data/MNIST", train=False, download=False, transform=tf)

# Shuffled mini-batches of 256 images; drop_last is not set, so the final
# batch of an epoch may be smaller.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=256,
                                           shuffle=True,
                                           num_workers=1
                                           )

# Validation loader, currently disabled together with val_dataset above.
# val_loader = torch.utils.data.DataLoader(val_dataset,
#                                         batch_size=256,
#                                         shuffle=True,
#                                         drop_last=True,
#                                         num_workers=1
#                                         )

def pad(tensor):
    """Append three singleton axes to a per-sample vector.

    Args:
        tensor: 1-D tensor of shape ``(b,)``.

    Returns:
        Tensor of shape ``(b, 1, 1, 1)`` holding the same values, suitable
        for broadcasting against 4-D batches such as ``(b, c, h, w)`` images.
    """
    broadcast_pattern = 'b -> b 1 1 1'
    return repeat(tensor, broadcast_pattern)

In [2]:
class RectifiedFlow():
    """Rectified-flow wrapper around a velocity model called as ``model(z, t)``.

    Training pairs are built by linear interpolation
    ``z_t = t * z1 + (1 - t) * z0`` with regression target ``z1 - z0``.
    Sampling integrates the model output from ``t = 0`` towards ``t = 1``
    with fixed-step forward Euler.

    Args:
        model: ``torch.nn.Module`` invoked as ``model(z, t)`` where ``z`` is a
            batch and ``t`` is a 1-D tensor with one time value per sample.
        device: device the model is moved to and on which time tensors are
            created. When ``None`` the model's current device is used
            (CPU for a model without parameters).
        num_steps: default number of Euler steps used by ``sample_ode``.

    Raises:
        ValueError: if ``model`` is ``None``.
    """

    def __init__(self, model=None, device=None, num_steps=1000):
        if model is None:
            # The default exists only for keyword-style construction; without
            # a model nothing below can work, so fail with a clear message.
            raise ValueError("RectifiedFlow requires a model; got model=None")
        if device is None:
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
        self.model = model
        self.model.to(device)
        self.N = num_steps
        self.device = device

    @staticmethod
    def _broadcast_time(t, like):
        """Reshape ``t`` from ``(b,)`` to ``(b, 1, ..., 1)`` matching ``like``'s rank."""
        return t.reshape(-1, *([1] * (like.dim() - 1)))

    def get_train_tuple(self, z0=None, z1=None):
        """Draw random times and build one interpolation training tuple.

        Args:
            z0: batch of source samples, same shape as ``z1``.
            z1: batch of target samples; the first axis is the batch axis.
                Any number of trailing axes is supported.

        Returns:
            ``(z_t, t, target)`` where ``t`` has shape ``(b,)`` with values
            drawn uniformly from ``[0, 1)``, ``z_t = t * z1 + (1 - t) * z0``
            and ``target = z1 - z0``.

        Raises:
            ValueError: if ``z0`` or ``z1`` is ``None``.
        """
        if z0 is None or z1 is None:
            raise ValueError("get_train_tuple needs both z0 and z1")
        # Create the times directly on the target device (no host round trip).
        t = torch.rand(z1.shape[0], device=self.device)
        t_b = self._broadcast_time(t, z1)
        z_t = t_b * z1 + (1. - t_b) * z0
        target = z1 - z0
        return z_t, t, target

    @torch.no_grad()
    def sample_ode(self, z0=None, N=None):
        """Integrate the model from ``z0`` with ``N`` forward-Euler steps.

        The model is switched to evaluation mode for the duration of the
        solve and its previous mode is restored afterwards, so layers such
        as dropout do not inject training-time noise into the samples.

        Args:
            z0: initial batch; it is copied, never modified.
            N: number of Euler steps; defaults to ``num_steps`` given at
                construction.

        Returns:
            List of ``N + 1`` tensors on ``self.device``: the initial state
            followed by the state after each step.

        Raises:
            ValueError: if ``z0`` is ``None`` or ``N`` is smaller than 1.
        """
        if z0 is None:
            raise ValueError("sample_ode needs an initial state z0")
        if N is None:
            N = self.N
        if N < 1:
            raise ValueError(f"N must be at least 1, got {N}")
        dt = 1. / N
        z = z0.detach().clone().to(self.device)
        # Snapshots are cloned so stored states never alias the working tensor.
        trajectory = [z.clone()]

        was_training = self.model.training
        self.model.eval()
        try:
            for i in range(N):
                # Time at the start of step i, one entry per sample.
                t = torch.ones((z.shape[0],), device=self.device) * i / N
                pred = self.model(z, t)
                z = z + pred * dt
                trajectory.append(z.clone())
        finally:
            # Hand the model back in the mode the caller left it in.
            self.model.train(was_training)
        return trajectory

In [3]:
def train_rectified_flow(data_loader, rectified_flow, opt, device):
    """Train ``rectified_flow.model`` for one pass over ``data_loader``.

    For every batch, Gaussian noise ``z0`` is paired with the data ``z1``,
    an interpolated training tuple is requested from
    ``rectified_flow.get_train_tuple`` and the model output is regressed onto
    the returned target with a mean-squared-error loss.

    Args:
        data_loader: iterable yielding ``(inputs, labels)`` pairs; the labels
            are ignored. It does not need to support ``len()``.
        rectified_flow: object exposing ``model`` (called as ``model(z_t, t)``)
            and ``get_train_tuple(z0, z1)``.
        opt: optimizer over the model's parameters.
        device: device the input batches are moved to.

    Returns:
        Mean of the per-batch losses as a Python float.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    rectified_flow.model.train()
    # Clear gradients left on the parameters by earlier work so that the
    # first optimizer step of this epoch is driven by its own batch only.
    opt.zero_grad()

    running_loss = 0.0
    num_batches = 0
    for z1, _ in data_loader:
        z1 = z1.to(device)
        z0 = torch.randn_like(z1)  # created on z1's device already

        z_t, t, target = rectified_flow.get_train_tuple(z0, z1)
        pred = rectified_flow.model(z_t, t)
        loss = F.mse_loss(pred, target)

        loss.backward()
        opt.step()
        opt.zero_grad()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches; cannot compute an average loss")
    return running_loss / num_batches

In [4]:
# Init all of our models
# Tunable settings for this run, named so they are easy to find and change.
SEED = 0
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4

# Seed PyTorch before the model is created so weight initialisation and the
# later random draws start from a fixed state and a re-run can be compared
# with a previous one. (Some GPU kernels may still be non-deterministic.)
torch.manual_seed(SEED)

model = UNetModel()
RF = RectifiedFlow(model, device)

print("Number of parameters: ", sum(p.numel() for p in model.parameters()))

opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# One call to train_rectified_flow is one pass over train_loader; the value it
# returns is reported as that epoch's loss.
for i in tqdm(range(NUM_EPOCHS)):
    loss_rec = train_rectified_flow(train_loader, RF, opt, device)
    print('loss from epoch ', i, ': ', loss_rec)

Number of parameters:  3607873


  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# Sample one image starting from Gaussian noise and show evenly spaced
# snapshots of the sampling trajectory, from the initial noise (left) to the
# final state (right).
z = torch.randn((1, 1, 28, 28))
trajectory = RF.sample_ode(z0=z)

# `trajectory` is the list of states returned by sample_ode.
fig, axs = plt.subplots(1, 11, figsize=(20, 2))  # Adjust figsize as needed

# Spread the panels over the whole trajectory instead of assuming 1000 steps;
# for 1001 stored states and 11 panels this is the stride of 100 used before.
stride = max((len(trajectory) - 1) // (axs.size - 1), 1)

for i, ax in enumerate(axs.flat):
    idx = i * stride
    if idx < len(trajectory):
        img = trajectory[idx].reshape(28, 28).detach().cpu().numpy()  # 28x28 array for display
        ax.imshow(img, cmap='gray')  # Plot as grayscale image
        ax.set_title(f'Index {idx}')
    ax.axis('off')  # Also hides panels that fall beyond the list length

plt.tight_layout()
plt.show()


In [None]:
# Sample 10 images and show each one's trajectory as a row of snapshots.
z = torch.randn((10, 1, 28, 28))

# Solve the ODE once for the whole batch instead of once per sample: the number
# of model calls drops by a factor of 10.
# NOTE(review): assumes the model treats batch entries independently, so a
# batched solve matches the per-sample solves - confirm for UNetModel.
batched_trajectory = RF.sample_ode(z0=z)

# Keep the original layout: one list of (1, 1, 28, 28) states per sample.
trajectories = [[state[i:i + 1] for state in batched_trajectory] for i in range(z.shape[0])]

# 10 rows of trajectories, each with 11 snapshot columns
fig, axs = plt.subplots(10, 11, figsize=(20, 20))  # Adjust figsize as needed, ensure there's enough space

# Spread the columns over the whole trajectory instead of assuming 1000 steps;
# for 1001 stored states and 11 columns this is the stride of 100 used before.
stride = max((len(batched_trajectory) - 1) // (axs.shape[1] - 1), 1)

for row, trajectory in enumerate(trajectories):
    for col, ax in enumerate(axs[row]):
        idx = col * stride
        if idx < len(trajectory):
            img = trajectory[idx].reshape(28, 28).detach().cpu().numpy()  # 28x28 array for display
            ax.imshow(img, cmap='gray')  # Plot as grayscale image
            ax.set_title(f'Index {idx}')
        ax.axis('off')  # Also hides panels that fall beyond the list length

plt.tight_layout()
plt.show()

In [None]:
# Coarse solve: 10 Euler steps, so sample_ode returns 11 states (the initial
# noise plus one state per step) and every state gets its own panel.
z = torch.randn(1, 1, 28, 28)
trajectory = RF.sample_ode(z0=z, N=10)

fig, axs = plt.subplots(nrows=1, ncols=11, figsize=(20, 2))  # Adjust figsize as needed

num_states = len(trajectory)
for idx, ax in enumerate(axs.flat):
    if idx >= num_states:
        # No state left for this panel: leave it blank.
        ax.axis('off')
        continue
    frame = trajectory[idx].reshape(28, 28).detach().cpu().numpy()  # 28x28 array for display
    ax.imshow(frame, cmap='gray')
    ax.set_title(f'Step {idx}')
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Sanity check on the most recently computed `trajectory`: sample_ode stores the
# initial state plus one state per Euler step, so an N = 10 solve gives 11.
len(trajectory)