In [1]:
import os
import copy
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import torch.nn.functional as F
from config import get_config
from datamodule.data_module import DataModule
from models.av_net import AVNet
import numpy as np

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
os.environ['CUDA_VISIBLE_DEVICES'] = '6,7'

In [3]:

def is_deepcopyable(obj):
    try:
        copy.deepcopy(obj)
        return True
    except Exception:
        return False


class AVSRModule(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Convert config dict to a flat dictionary for hyperparameters,
        # but only include items that can be deep-copied.
        hparams = {}
        for section, params in config.items():
            if isinstance(params, dict):
                for key, value in params.items():
                    if is_deepcopyable(value):
                        hparams[f"{section}_{key}"] = value
            else:
                if is_deepcopyable(params):
                    hparams[section] = params
        self.save_hyperparameters(hparams)
        
        # Model initialization
        model_args = (
            config["model"]["d_model"],
            config["model"]["n_heads"],
            config["model"]["n_layers"],
            config["model"]["pe_max_len"],
            config["model"]["fc_hidden_size"],
            config["model"]["dropout"]
        )
        
        self.model = AVNet(
            modal="AV",
            MoCofile=os.path.join(os.getcwd(), config["data"]["moco_file"]),
            reqInpLen=config["model"]["required_input_length"],
            modelargs=model_args
        )
        
        # MSE Loss for feature learning
        self.loss_fn = nn.MSELoss()
        
    def training_step(self, batch, batch_idx):
        # Prepare input data
        input_data = (
            batch["audios"],
            batch["audio_attention_mask"],
            batch["videos"],
            batch["video_attention_mask"]
        )
        
        # Forward pass
        features = self.model(input_data)
        
        # Calculate MSE loss between audio and video features
        loss = self.loss_fn(features, batch["audios"])
        
        # Log training loss
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        
        # Log gate attention weights
        if hasattr(self.model, 'fusion_module'):
            for i, layer in enumerate(self.model.fusion_module.layers):
                self.log(f'train_attn_gate_{i}', layer.attn_gate.item(), on_step=False, on_epoch=True)
                self.log(f'train_ff_gate_{i}', layer.ff_gate.item(), on_step=False, on_epoch=True)
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        # Prepare input data
        input_data = (
            batch["audios"],
            batch["audio_attention_mask"],
            batch["videos"],
            batch["video_attention_mask"]
        )
        
        # Forward pass
        features = self.model(input_data)
        
        # Calculate MSE loss between audio and video features
        loss = self.loss_fn(features, batch["audios"])
        
        # Log validation loss
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        
        # Calculate and log cosine similarity
        cos_sim = F.cosine_similarity(features.mean(1), batch["audios"].mean(1))
        self.log('val_cosine_sim', cos_sim.mean(), on_epoch=True, prog_bar=True)
        
        return loss
        
    def configure_optimizers(self):
        # Separate parameters for weight decay
        param_optimizer = list(self.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'gate']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                'weight_decay': self.config["training"]["weight_decay"]
            },
            {
                'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]
        
        # Create optimizer
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=0.0,  # Will be set by the scheduler
            betas=(0.9, 0.98),
            eps=1e-6
        )
        
        # Learning rate scheduler with warmup
        num_training_steps = self.trainer.estimated_stepping_batches
        num_warmup_steps = int(num_training_steps * 0.1)  # 10% warmup
        
        def lr_lambda(current_step):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            return max(
                0.0,
                float(num_training_steps - current_step) / 
                float(max(1, num_training_steps - num_warmup_steps))
            )
        
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda,
            last_epoch=-1
        )
        
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }


In [4]:
# Load configuration
config = get_config()

# Initialize data module and model
data_module = DataModule(config)
model = AVSRModule(config)

# Setup callbacks
callbacks = [
    # Model checkpointing
    ModelCheckpoint(
        dirpath=config["output"]["checkpoint_dir"],
        filename='avsr-{epoch:02d}-{val_loss:.2f}',
        save_top_k=config["output"]["save_top_k"],
        monitor=config["output"]["monitor"],
        mode=config["output"]["monitor_mode"]
    ),
    # Early stopping
    EarlyStopping(
        monitor=config["output"]["monitor"],
        patience=config["training"]["early_stopping_patience"],
        mode=config["output"]["monitor_mode"]
    ),
    # Learning rate monitor
    pl.callbacks.LearningRateMonitor(logging_interval='step')
]

  self.visual_model.load_state_dict(torch.load(MoCofile, map_location="cpu"), strict=False)


In [5]:
# Setup logger with detailed metrics
logger = TensorBoardLogger(
    save_dir=config["output"]["log_dir"],
    name='avsr_logs',
    default_hp_metric=False
)
# Initialze trainer with improved settings
trainer = pl.Trainer(
    max_epochs=config["training"]["epochs"],
    callbacks=callbacks,
    logger=logger,
    gradient_clip_val=config["training"]["gradient_clip_val"],
    accumulate_grad_batches=config["training"].get("accumulate_grad_batches", 1),
    precision=16, 
    accelerator='auto',
    devices="auto",
    strategy='ddp_notebook',
    deterministic=False,
    benchmark=True,
    sync_batchnorm=True
)

/opt/miniconda/envs/whisper-flamingo/lib/python3.8/site-packages/lightning_fabric/connector.py:558: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using 16bit Automatic Mixed Precision (AMP)
/opt/miniconda/envs/whisper-flamingo/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/amp.py:54: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
trainer.fit(model, data_module)

Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6,7]
Loading `train_dataloader` to estimate number of stepping batches.


ProcessExitedException: process 0 terminated with signal SIGSEGV