In [1]:
import torch, os
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import matplotlib.pyplot as plt
from tqdm import tqdm
from einops import rearrange, repeat

from torchvision.utils import save_image, make_grid
from torchvision import transforms
from torchvision.datasets import MNIST

from modules.networks.Unet import ContextUnet

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use only the first GPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1
val_dataset = MNIST("/data/edherron/data/MNIST", train=False, download=False, transform=tf)

val_loader = torch.utils.data.DataLoader(val_dataset, 
                                        batch_size=256, 
                                        shuffle=True, 
                                        drop_last=True, 
                                        num_workers=1
                                        )

    
class Encoder(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        # Initial convolution block
        layers = [nn.Conv2d(1, dim, 4), nn.InstanceNorm2d(dim * 2), nn.Tanh()]
        # Downsampling
        for _ in range(4):
            layers += [nn.Conv2d(dim, dim * 2, 4), nn.InstanceNorm2d(dim * 2), nn.Tanh()]
            dim *= 2
        # Final block with 1D convolution for demonstration purposes
        layers += [nn.Conv2d(dim, 6, 1), nn.InstanceNorm2d(1), nn.Tanh() ]
        self.model_blocks = nn.Sequential(*layers)

    def forward(self, x):
        x = self.model_blocks(x)
        return x

class Decoder(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        dim = dim * 2 ** 3
        layers = [nn.Conv2d(6,dim,1), nn.InstanceNorm2d(dim), nn.LeakyReLU(0.2, inplace=True)]
        # Upsampling
        for _ in range(5):
            layers += [nn.ConvTranspose2d(dim, dim // 2, 4), nn.InstanceNorm2d(dim // 2), nn.LeakyReLU(0.2, inplace=True)]
            dim = dim // 2
        # Output layer
        layers += [nn.Conv2d(dim, 1, 1)]
        self.model_blocks = nn.Sequential(*layers, nn.Sigmoid())

    def forward(self, x):
        x = self.model_blocks(x)
        return x
    
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.ts = nn.Linear(1, 6)
        
    def forward(self, x, t):
        latent = self.encoder(x)
        ts = repeat(self.ts(t), 'b v -> b v 1 1')
        eps = self.decoder(latent + ts)
        return eps
    
def pad(tensor):
    return repeat(tensor, 'b 1 -> b 1 1 1')

In [2]:
class RectifiedFlow():
    def __init__(self, model=None, device=None, num_steps=1000):
        self.model = Network()
        self.model.to(device)
        self.N = num_steps
        self.device = device
        
    def get_train_tuple(self, z0=None, z1=None):
        t = torch.rand((z1.shape[0],1)).to(self.device)
        z_t = pad(t) * z1 + (1. - pad(t)) * z0
        target = z1 - z0
        return z_t, t, target
    
    @torch.no_grad()
    def sample_ode(self, z0=None, N=None):
        if N is None:
            N = self.N
        dt = 1./N
        trajectory = []
        z = z0.detach().clone().to(self.device)
        
        
        trajectory.append(z.detach().clone())
        for i in range(N):
            t = torch.ones((z.shape[0],)) * i / N
            t = t.to(self.device)
            pred = self.model(z, t)
            z = z.detach().clone() + pred * dt
            
            trajectory.append(z.detach().clone())
        return trajectory

In [3]:
def train_rectified_flow(data_loader, rectified_flow, opt, device):
    rectified_flow.model.train()
    running_loss = 0.0
    for data in data_loader:
        z1, _ = data
        z1 = z1.to(device)
        # z1 = rearrange(z1.to(device), 'b c h w -> b (c h w)')
        z0 = torch.randn_like(z1).to(device)
        
        z_t, t, target = rectified_flow.get_train_tuple(z0, z1)
        
        pred = rectified_flow.model(z_t, t)
        
        loss = F.mse_loss(pred, target)
        
        loss.backward()
        opt.step()
        opt.zero_grad()
        running_loss += loss.item()
    avg_loss = running_loss / len(data_loader)
    return avg_loss

In [4]:
# Init all of our models
model = Network()
RF = RectifiedFlow(model, device)

print("Number of parameters: ", sum(p.numel() for p in model.parameters()))

opt = torch.optim.Adam(model.parameters(), lr=3e-4)

for i in tqdm(range(5)):
    loss_rec = train_rectified_flow(val_loader, RF, opt, device)
    print('loss from epoch ', i, ': ', loss_rec)

Number of parameters:  3490291


 20%|██        | 1/5 [00:05<00:20,  5.04s/it]

loss from epoch  0 :  1.3196749075865135


 40%|████      | 2/5 [00:09<00:14,  4.99s/it]

loss from epoch  1 :  1.3210092935806665


 60%|██████    | 3/5 [00:15<00:10,  5.07s/it]

loss from epoch  2 :  1.3206556362983508


 80%|████████  | 4/5 [00:20<00:05,  5.18s/it]

loss from epoch  3 :  1.321506808965634


100%|██████████| 5/5 [00:25<00:00,  5.20s/it]

loss from epoch  4 :  1.3215078971324823





In [None]:
z = torch.randn((1,28**2))
trajectory = RF.sample_ode(z0 = z)

In [None]:
print(len(trajectory))

In [None]:
print(len(trajectory))
print(type(trajectory[69]))
print(trajectory[69].shape)

In [None]:
# Assuming `data_list` is your list of tensors
fig, axs = plt.subplots(1, 11, figsize=(20, 2))  # Adjust figsize as needed

for i, ax in enumerate(axs.flat):
    idx = i * 100  # Every 100th element
    if idx < len(trajectory):
        img = trajectory[idx].reshape(28, 28).detach().cpu().numpy()  # Reshape tensor to 28x28 for visualization
        ax.imshow(img, cmap='gray')  # Plot as grayscale image
        ax.set_title(f'Index {idx}')
        ax.axis('off')
    else:
        ax.axis('off')  # Hide axes for plots beyond the list length

plt.tight_layout()
plt.show()


In [None]:
def train_loop(data_loader, ae, opt, device):
    ae.train()
    running_loss = 0.0
    for data in data_loader:
        images, _ = data
        images = rearrange(images.to(device), 'b c h w -> b (c h w)')
        recon = ae(images)
        loss = F.mse_loss(recon, images)
        loss.backward()
        opt.step()
        opt.zero_grad()
        running_loss += loss.item()
    avg_loss = running_loss / len(data_loader)
    return avg_loss

In [None]:
# Init all of our models
ae = MLP_AE()
ae.to(device)
print("Number of parameters: ", sum(p.numel() for p in ae.parameters()))

opt = torch.optim.Adam(ae.parameters(), lr=3e-4)

for i in tqdm(range(25)):
    loss_rec = train_loop(val_loader, ae, opt, device)
    print('loss from epoch ', i, ': ', loss_rec)

In [None]:
ae.eval()
image, label = val_dataset[737]
image = rearrange(image.to(device), 'b h w -> b (h w)')
pred = ae(image)
pred = rearrange(pred, 'b (h w) -> b h w', h=28, w=28)
image = rearrange(image, 'b (h w) -> b h w', h=28, w=28)

fig, (ax1, ax3) = plt.subplots(1, 2, figsize=(10, 5))

ax1.imshow(image.detach().cpu().squeeze().numpy())
ax3.imshow(pred.detach().cpu().squeeze().numpy())