In [1]:
import torch.nn as nn
import copy
from video_dataset import VideoDataset, load_data
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
from utils import trunc_normal_
import torch
import vision_transformer as vit
from load_model import load_models
import logging
import os
from video import generate_avi

# Send INFO-and-above records to training.log; filemode='w' truncates the
# file every time this cell is executed.
logging.basicConfig(
    filename='training.log',
    filemode='w',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

In [2]:
# Build the four networks trained below. load_models() prints parameter
# counts and weight-initialisation messages as a side effect.
(
    encoder,
    predictor,
    action_conditioner,
    diffusion_model,
) = load_models()

Loaded encoder
Number of encoder parameters: 113998080
Loaded predictor
Number of predictor parameters: 213282816
Loaded action_conditioner
Number of action parameters: 42696960
Loaded diffusion_model
Number of parameters: 141838848
Initializing encoder weights
Initializing predictor weights
Initializing action weights
Initializing diffusion weights


In [3]:
# Shell escape: print the GPU model, driver/CUDA version and current memory
# usage before training starts.
!nvidia-smi

Wed Apr 17 20:26:27 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3060        Off | 00000000:01:00.0  On |                  N/A |
|  0%   43C    P2              39W / 170W |   2605MiB / 12288MiB |     80%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [4]:
# Resume from the checkpoint saved for epoch EPOCH. Files follow the layout
# written by the training loop: ./checkpoints/<EPOCH>/<name>_weights_<EPOCH>.pt
EPOCH = 155
# map_location='cpu' deserialises the tensors on the host, so the checkpoint
# loads on a CPU-only machine too (matching the cuda/cpu fallback used for
# `device`) and no temporary GPU copy of the weights is allocated.
# load_state_dict then copies the values onto whatever device each model
# already lives on.
encoder.load_state_dict(torch.load(f'./checkpoints/{EPOCH}/encoder_weights_{EPOCH}.pt', map_location='cpu'))
predictor.load_state_dict(torch.load(f'./checkpoints/{EPOCH}/predictor_weights_{EPOCH}.pt', map_location='cpu'))
action_conditioner.load_state_dict(torch.load(f'./checkpoints/{EPOCH}/action_conditioner_weights_{EPOCH}.pt', map_location='cpu'))
diffusion_model.load_state_dict(torch.load(f'./checkpoints/{EPOCH}/diffusion_model_weights_{EPOCH}.pt', map_location='cpu'))

<All keys matched successfully>

In [8]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# One AdamW optimizer drives the encoder, predictor and action conditioner
# together; StepLR multiplies its learning rate by 0.1 every 20 scheduler steps.
optimizer = torch.optim.AdamW(
    [
        *encoder.parameters(),
        *predictor.parameters(),
        *action_conditioner.parameters(),
    ],
    lr=3e-6,
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# The diffusion model has its own Adam optimizer and no scheduler
diffusion_optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=3e-6)

# Define the loss function
def l2_loss(predictions, targets):
    """Mean Euclidean distance between predictions and targets.

    The element-wise difference is reduced with an L2 norm over ``dim=1`` and
    the resulting norms are averaged over every remaining dimension.

    Args:
        predictions: Tensor with at least two dimensions.
        targets: Tensor broadcastable against ``predictions``.

    Returns:
        Scalar tensor holding the mean L2 norm.
    """
    # vector_norm returns the same value as sqrt(sum(diff ** 2)) but uses a
    # zero subgradient where the norm is exactly 0. The explicit sqrt form
    # backpropagates inf * 0 = NaN at that point, which would corrupt every
    # parameter touched by the update.
    distances = torch.linalg.vector_norm(predictions - targets, ord=2, dim=1)
    return torch.mean(distances)

# Define an MSE loss for the diffusion model
def diffusion_loss(predictions, targets):
    """Return the mean squared error between ``predictions`` and ``targets``.

    Args:
        predictions: Tensor produced by the diffusion model.
        targets: Tensor of reference values.

    Returns:
        Scalar tensor with the element-wise squared error averaged over all
        elements.
    """
    mse = F.mse_loss(input=predictions, target=targets, reduction='mean')
    return mse

In [None]:
# Training loop configuration
epochs = 500       # upper bound (exclusive) of the epoch range used below
ema_decay = 0.999  # EMA decay rate
video_num = 2268   # index of the next video passed to generate_avi

# The EMA encoder starts as an exact copy of the current encoder and is
# placed on the training device.
ema_encoder = copy.deepcopy(encoder).to(device)

@torch.no_grad()
def update_ema(model, ema_model, decay):
    """Blend ``model``'s parameters into ``ema_model`` in place.

    Each EMA parameter becomes ``decay * ema + (1 - decay) * param``.
    Parameters are paired positionally, so both modules must yield their
    parameters in the same order. Buffers are not touched.

    Args:
        model: Module providing the current weights.
        ema_model: Module holding the running average; modified in place.
        decay: Fraction of the previous average that is kept.
    """
    blend = 1 - decay
    for source, shadow in zip(model.parameters(), ema_model.parameters()):
        shadow.data.mul_(decay)
        shadow.data.add_(source.data, alpha=blend)

# Resume training from the epoch after the loaded checkpoint. Each epoch loads
# a fresh window of videos, trains the JEPA stack (encoder -> action
# conditioner -> predictor) against EMA-encoder targets, trains the diffusion
# model to reconstruct the next frame from those targets, then checkpoints.
#
# NOTE(review): only the four model state dicts are checkpointed. Optimizer,
# scheduler and EMA-encoder state are rebuilt from scratch on resume.
accumulation_steps = 1  # micro-batches accumulated per optimizer step
n_videos = 25           # videos loaded per epoch
reg_coeff = 0.0002      # Regularization coefficient

for epoch in range(EPOCH+1, epochs):
    # Load a new dataset: the window of videos starts at epoch * n_videos
    train_dataloader, test_dataloader = load_data(
        data_folder='./datas/find-cave/',
        start_idx=epoch*n_videos,
        n_videos=n_videos,
        action_sequence_length=25,
        split_ratio=0.75,
        batch_size=4,
        num_workers=2,
        shuffle=True,
        frame_skip=1
    )

    # Reset so the end-of-epoch log cannot report a stale value (or raise a
    # NameError) when the dataloader yields no batches.
    loss = None

    # Number of passes over this epoch's data (currently a single pass)
    for i in range(1):
        for batch_idx, batch in enumerate(tqdm(train_dataloader)):
            frame_t, action_sequence, frame_tp1 = batch
            frame_t = frame_t.to(device)
            action_sequence = action_sequence.to(device)
            frame_tp1 = frame_tp1.to(device)

            # Train mode is re-asserted every batch.
            # NOTE(review): presumably because generate_avi may switch the
            # models to eval mode — confirm.
            encoder.train()
            predictor.train()
            action_conditioner.train()
            diffusion_model.train()

            # Forward pass for the JEPA model
            pred = encoder(frame_t)
            pred = action_conditioner(pred, action_sequence)
            pred = predictor(pred)

            # Compute target using EMA encoder; no gradient flows into it
            with torch.no_grad():
                target = ema_encoder(frame_tp1)
                target = F.layer_norm(target, (target.size(-1),))

            # Compute loss for prediction network. The regulariser is a hinge
            # that penalises a standard deviation below 1 along dim=1
            # (assumed to be the patch dimension — TODO confirm).
            loss_pred = l2_loss(pred, target)
            pstd_pred = torch.std(pred, dim=1)
            loss_reg = torch.mean(F.relu(1. - pstd_pred))
            loss = loss_pred + reg_coeff * loss_reg

            # Forward pass for the diffusion model: map the target embedding
            # back to a 3x224x224 frame. `target` was built under no_grad, so
            # this loss only reaches diffusion_model's parameters.
            recon = diffusion_model(target)
            recon = recon.view(-1, 3, 224, 224)

            # Compute loss for diffusion model
            diff_loss = diffusion_loss(recon, frame_tp1)

            # Accumulate gradients (scaled so the accumulated sum is a mean)
            loss = loss / accumulation_steps
            diff_loss = diff_loss / accumulation_steps
            loss.backward()
            diff_loss.backward()

            # Perform optimization step after accumulating gradients
            if (batch_idx + 1) % accumulation_steps == 0:
                optimizer.step()
                diffusion_optimizer.step()
                optimizer.zero_grad()
                diffusion_optimizer.zero_grad()

                # Update the EMA encoder once per optimizer step. Updating on
                # every micro-batch would re-blend unchanged weights and
                # shorten the effective decay when accumulation_steps > 1.
                update_ema(encoder, ema_encoder, ema_decay)

            # Log both losses every 100th batch (undoing the accumulation scaling)
            if batch_idx % 100 == 0:
                logger.info(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}/{len(train_dataloader)}], JEPA Loss: {loss.item() * accumulation_steps:.4f}")
                logger.info(f"Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}/{len(train_dataloader)}], Diffusion Loss: {diff_loss.item() * accumulation_steps:.4f}")

            # Render a sample video every 500th batch, advancing to the next
            # video each time.
            # NOTE(review): the `!= 235` guard only matters if video_num can
            # reach 235 — confirm whether it is still needed.
            if batch_idx % 500 == 0 and video_num != 235:
                generate_avi(encoder, predictor, action_conditioner, diffusion_model, video_num, path=f'./output_video_{epoch}_{batch_idx}.avi')
                logger.info(f'new_vid at {video_num}')
                video_num += 1

        # Update learning rate (once per pass over the data)
        scheduler.step()

    # Log progress: JEPA loss of the last batch of the epoch
    if loss is not None:
        logger.info(f"Epoch [{epoch+1}/{epochs}], Train Loss: {loss.item() * accumulation_steps:.4f}")
    else:
        logger.warning(f"Epoch [{epoch+1}/{epochs}], no training batches were produced")

    # Save checkpoint after every epoch
    print("saving checkpoints")
    checkpoint_dir = f'./checkpoints/{epoch}/'
    # exist_ok avoids the check-then-create race and makes re-runs idempotent
    os.makedirs(checkpoint_dir, exist_ok=True)
    for model_name, model in (
        ('encoder', encoder),
        ('predictor', predictor),
        ('action_conditioner', action_conditioner),
        ('diffusion_model', diffusion_model),
    ):
        torch.save(model.state_dict(), f'{checkpoint_dir}{model_name}_weights_{epoch}.pt')

    logger.info('Saved checkpoints.')