In [1]:
import torch, os
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
from einops import repeat

from modules.networks.unet import UNetModel

os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Use only the first GPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# replace with your own dataset and dataloader
# train_dataset = 

# train_loader = 

def pad(tensor):
    return repeat(tensor, 'b -> b 1 1 1')

In [2]:
class RectifiedFlow():
    def __init__(self, model=None, device=None, num_steps=10):
        self.device = device
        self.model = model.to(self.device)
        self.N = num_steps
        
    def get_train_tuple(self, x0=None, x1=None):
        # randomly sample timesteps for training - timesteps are analogous to 
        # points along the linear interpolation of x0 and x1.
        t = torch.rand((x0.shape[0])).to(self.device)
        t = F.sigmoid(t)
        
        # find interpolated x i.e., x_t
        x_t = pad(t) * x1 + (1. - pad(t)) * x0
        
        # find our ground truth target value (velocity) we want our network to
        # approximate. This velocity term is the time derivative of the linear 
        # interpolation above. ie dX_t/dt = d(t*x1 + (1-t)*x0)/dt
        velocity = x1 - x0
        
        return x_t, t, velocity
    
    def rectified_flow_loss(self, x0, x1):
        '''
        Loss function for rectified flow model.

        x0: input tensor of shape (batch_size, channels, height, width) Real Images
        x1: input tensor of shape (batch_size, channels, height, width) Sim Images
        
        output: loss value we will optimize params of self.model with.
        '''
        # initialize x0 and x1 and send to device
        x0 = x0.to(self.device)
        x1 = x1.to(self.device)
        
        # get inputs (x_t and t) for network and velocity value for loss function.
        xt, t, velocity = self.get_train_tuple(x0, x1)
        
        # make velocity prediction with network
        velocity_hat = self.model(xt, t)
        
        # compute loss between prediction and velocity and return
        return F.mse_loss(velocity_hat, velocity)
        
    @torch.no_grad()
    def sample_ode(self, x0=None, N=None):
        # initialize number of timesteps in ode solver
        if N is None:
            N = self.N
            
        # initialize delta t
        dt = 1./N
        
        # initialize x for solver
        x = x0.detach().clone().to(self.device)
        
        # Euler method integration scheme
        for i in range(N):
            # init timesteps and send to device
            t = torch.ones((x0.shape[0])) * i / N
            t = t.to(self.device)
            
            #make velocity prediction
            velocity = self.model(x, t)
            
            #update x_t+1
            x = x.detach().clone() + velocity * dt
            
        return x

In [6]:
def train_rectified_flow(data_loader, rectified_flow, opt):
    rectified_flow.model.train()
    running_loss = 0.0
    for data in data_loader:
        x0, x1 = data
        loss = rectified_flow.rectified_flow_loss(x1, x0)
        
        loss.backward()
        opt.step()
        opt.zero_grad()
        running_loss += loss.item()
    avg_loss = running_loss / len(data_loader)
    return avg_loss

In [7]:
# Init all of our models
model = UNetModel()
RF = RectifiedFlow(model, device)

print("Number of parameters: ", sum(p.numel() for p in model.parameters()))

opt = torch.optim.Adam(model.parameters(), lr=1e-4)

for i in tqdm(range(2)):
    loss_rec = train_rectified_flow(train_loader, RF, opt)
    print('loss from epoch ', i, ': ', loss_rec)
    

Number of parameters:  3607873


 50%|█████     | 1/2 [00:31<00:31, 31.13s/it]

loss from epoch  0 :  0.08503996425486625


100%|██████████| 2/2 [01:01<00:00, 30.80s/it]

loss from epoch  1 :  0.08467258681008157



