In [None]:
# Default export removed as this is now interactive-only

# Training Notebook (Interactive)

> Orchestrates the Indic-CLIP model training process using fast.ai's Learner and custom components.
> Designed for interactive execution within a notebook environment (e.g., Colab, Jupyter).

## Load libraries

In [1]:
#| hide
# Mount Google Drive (Optional, but recommended for persistent storage)
from pathlib import Path
import sys
import os

project_parent = '/workspace'
if Path('/workspace/indic-clip').exists():
    project_parent = '/workspace/indic-clip'
    if project_parent not in sys.path:
        sys.path.insert(0, project_parent)
        print(f"Added {project_parent} to sys.path")
    try:
        import indic_clip.core
        print("Imported indic_clip.core after path adjustment.")
    except ModuleNotFoundError:
        print("ERROR: Still cannot find indic_clip.core. Ensure project structure is correct.")
        print("Expected: /workspace/indic-clip/indic-clip/core.py or similar in Drive")          

Added /workspace/indic-clip to sys.path
Imported indic_clip.core after path adjustment.


In [2]:
#| hide
# Install requirements if needed (especially in Colab)
# !pip install -qr requirements.txt

In [3]:
# --- Standard Library Imports ---
import os
import warnings
from pathlib import Path
import math
import sys

# --- Pypi Library Imports ---
import torch
import wandb
import numpy as np
from fastcore.all import *

# --- fastai Imports ---
from fastai.vision.all import *
from fastai.text.all import *
from fastai.data.all import *
from fastai.callback.wandb import *
from fastai.callback.progress import ProgressCallback # Ensure ProgressCallback is imported
from fastai.callback.schedule import ParamScheduler # For LR schedule logging
from fastai.callback.fp16 import MixedPrecision # For AMP
from fastai.callback.training import GradientClip, GradientAccumulation
from fastai.callback.tracker import EarlyStoppingCallback, SaveModelCallback

# --- Project Imports ---
# Use try-except for robustness, especially during development/export
try:
    from indic_clip.core import * # Imports constants and utils
    from indic_clip.data.creation import IndicCLIPDataBlock, get_indic_clip_items
    from indic_clip.data.tokenization import IndicBERTTokenizer
    from indic_clip.model.clip import IndicCLIP
    from indic_clip.loss import ContrastiveLoss
    from indic_clip.learner import RetrievalMetricCallback # Custom callback for validation
    # Import the actual metric function
    from indic_clip.evaluation.metrics import calculate_retrieval_metrics
except ModuleNotFoundError as e:
    print(f"Error importing project modules: {e}")
    print("Please ensure the project is installed or sys.path is configured correctly.")
    # Define dummy classes/functions if imports fail, allowing script structure to be parsed
    class IndicCLIPDataBlock: pass
    def get_indic_clip_items(*args, **kwargs): return []
    class IndicBERTTokenizer: pass
    class IndicCLIP(torch.nn.Module): pass
    class ContrastiveLoss(torch.nn.Module): pass
    class RetrievalMetricCallback(Callback): pass
    def calculate_retrieval_metrics(*args, **kwargs): return {'mean_recall': 0.0, 'i2t_r@1': 0.0, 't2i_r@1': 0.0}

# --- Setup Logging ---
setup_logging()
logger = get_logger(__name__)

Reloaded indic_clip.core
Reloaded indic_clip.core
Reloaded indic_clip.core
Reloaded indic_clip.core
Reloaded indic_clip.core
Reloaded indic_clip.core
Imported indic_clip.core


## Configuration / Hyperparameters

In [4]:
# Define a simple class or use a dictionary to hold configuration
class TrainConfig:
    # --- Data Arguments ---
    processed_data_path: str = str(PROCESSED_DATA_PATH / 'filtered_data.jsonl')
    tokenizer_path: str = str(TOKENIZER_PATH)
    img_size: int = DEFAULT_IMAGE_SIZE
    max_seq_len: int = 128
    valid_pct: float = 0.05
    num_workers: int = 2 # Lower default for Colab compatibility
    use_augmentations: bool = True

    # --- Model Arguments ---
    vision_model_name: str = 'vit_base_patch16_224' # Vision Transformer Base
    text_model_name: str = PRETRAINED_TOKENIZER_NAME # e.g., "ai4bharat/indic-bert"
    vision_pretrained: bool = True
    text_pretrained: bool = True
    embed_dim: int = DEFAULT_EMBED_DIM # Should match vision/text models or projection target

    # --- Training Arguments ---
    epochs: int = 5
    batch_size: int = 32 # Adjust based on GPU memory
    lr: float = 5e-6 # Learning rate for ViT/BERT might need tuning
    wd: float = 0.01 # Weight decay
    beta1: float = 0.9
    beta2: float = 0.98 # AdamW defaults often work well
    eps: float = 1e-6
    warmup_steps: int = 1000 # Number of warmup steps for LR scheduler
    use_amp: bool = True # Use Automatic Mixed Precision
    grad_clip: float | None = None # Gradient clipping value (e.g., 1.0) or None
    accum_freq: int = 1 # Gradient accumulation frequency
    seed: int = 42

    # --- Checkpointing & Resuming ---
    checkpoint_dir: str = str(CHECKPOINT_PATH)
    save_model_name: str = 'best_recall_interactive' # Model saved by SaveModelCallback
    save_epoch_frequency: int = 1 # How often to save checkpoints regardless of metric
    resume_from: str | None = None # Path to a .pth checkpoint to resume
    early_stopping_patience: int = 5 # Epochs to wait for improvement before stopping

    # --- WandB Arguments ---
    wandb_project: str = 'Indic-CLIP-Interactive'
    wandb_entity: str | None = os.getenv('WANDB_ENTITY') # Read from env var
    wandb_run_name: str | None = None # Auto-generate run name if None
    wandb_log_model: str = 'best' # Options: 'best', 'all', 'false'

    # --- Debugging/Misc ---
    max_steps: int | None = None # Limit training steps for debugging

# Instantiate the config - MODIFY VALUES HERE FOR YOUR RUN
config = TrainConfig()

# Example modification for a quick test:
# config.epochs = 1
# config.batch_size = 16
# config.max_steps = 10
# config.wandb_entity = "your_wandb_entity" # IMPORTANT: SET THIS if not using env var
# config.vision_model_name = 'resnet18' # Use smaller model for faster testing

In [5]:
class InputInspectCallback(Callback):
    order = -5 # Run very early, before model prediction
    def before_batch(self):
        if self.training: # Only log during training for now
            logger.critical(f"InputInspectCB - learn.xb type: {type(self.learn.xb)}")
            if isinstance(self.learn.xb, (list, tuple)):
                logger.critical(f"InputInspectCB - learn.xb length: {len(self.learn.xb)}")
                logger.critical(f"InputInspectCB - Elem 0 (Image?) type: {type(self.learn.xb[0])}, shape: {self.learn.xb[0].shape}")
                if len(self.learn.xb) > 1:
                     logger.critical(f"InputInspectCB - Elem 1 (TextTuple?) type: {type(self.learn.xb[1])}")
                     if isinstance(self.learn.xb[1], tuple):
                         logger.critical(f"InputInspectCB - TextTuple length: {len(self.learn.xb[1])}")
                         logger.critical(f"InputInspectCB - TextTuple[0] shape: {self.learn.xb[1][0].shape}")
                         logger.critical(f"InputInspectCB - TextTuple[1] shape: {self.learn.xb[1][1].shape}")
                     else:
                          logger.error("!!! InputInspectCB - Elem 1 is NOT a tuple !!!")
                else:
                     logger.error("!!! InputInspectCB - Only one element in learn.xb !!!")
            else:
                 logger.error(f"!!! InputInspectCB - learn.xb is NOT a tuple/list !!! Type: {type(self.learn.xb)}")
        # Don't CancelBatchException here, just observe

In [6]:
class GradientDebugCallback(Callback):
    def after_backward(self):
        if not self.training: return # Only check during training steps
        grad_norms = []
        all_finite = True
        for name, param in self.learn.model.named_parameters():
            if param.grad is not None:
                if not torch.isfinite(param.grad).all():
                    logger.error(f"!!! Non-finite gradient detected in parameter: {name} !!!")
                    all_finite = False
                    # Optional: Log the grad values
                    # logger.error(f"Grad values: {param.grad.detach().cpu().flatten()[:10]}...") # Log first few values
                    # break # Stop checking after first NaN/Inf grad? Or check all?
                # grad_norms.append(param.grad.norm().item()) # Collect norms if needed
            # else: logger.debug(f"No grad for param: {name}") # Can be noisy

        if not all_finite:
            # Optionally: Zero gradients to prevent optimizer step with NaNs?
            # self.learn.opt.zero_grad()
            # logger.warning("Zeroed gradients due to non-finite values.")
            pass # Just log for now

        # Optional: Log norm distribution if desired
        # if grad_norms:
        #     logger.info(f"Gradient norms - Min: {min(grad_norms):.4e}, Max: {max(grad_norms):.4e}, Mean: {np.mean(grad_norms):.4e}")

In [7]:
class DebugRecorderStateCallback(Callback):
    order = Recorder.order + 1 # Run after Recorder but potentially before SaveModelCallback

    def after_epoch(self):
        if hasattr(self.learn, 'recorder'):
            rec = self.learn.recorder
            epoch = self.learn.epoch
            logger.info(f"--- Debug CB (Epoch {epoch}) BEFORE SaveModel ---")
            logger.info(f"Recorder metric_names: {getattr(rec, 'metric_names', 'N/A')}")
            logger.info(f"Recorder log values: {getattr(rec, 'log', 'N/A')}")
            logger.info(f"Recorder final_record: {getattr(rec, 'final_record', 'N/A')}")
            # Specifically check for valid_loss
            val_loss = getattr(rec, 'final_record', None)
            if val_loss is not None and len(val_loss) > 2: # Assuming train_loss, valid_loss order
                logger.info(f"Value for 'valid_loss' (via final_record[2]): {val_loss[2]}")
            else:
                # Try getting from log dict if final_record isn't available or structured differently
                #log_dict = dict(zip(getattr(rec,'metric_names',[]), getattr(rec, 'log', [])))
                vals = rec.values[-1]
                log_dict = {n: float(v) for n,v in zip(rec.metric_names, vals)}
                logger.info(f"Value for 'valid_loss' (via log dict): {log_dict.get('valid_loss', 'Not Found')}")

            logger.info(f"--- End Debug CB ---")
        else:
            logger.warning("Debug CB: Recorder not found.")

In [8]:
class SimpleWandbCallback(Callback):
    order = Recorder.order + 1 # Run after Recorder to ensure metrics are calculated and logged
    remove_on_fetch = True

    def __init__(self, log_model_policy='best'):
        store_attr()
        self._wandb_step = 0
        self._wandb_epoch = 0
        self.run = None # Initialize run attribute

    def _find_recorder_cb(self):
        "Helper to find Recorder callback."
        if not hasattr(self, 'learn') or not hasattr(self.learn, 'cbs'): return None
        for cb in self.learn.cbs:
            if isinstance(cb, Recorder): return cb
        return None

    def _find_save_model_cb(self):
        "Helper to find SaveModelCallback."
        if not hasattr(self, 'learn') or not hasattr(self.learn, 'cbs'): return None
        for cb in self.learn.cbs:
            if isinstance(cb, SaveModelCallback): return cb
        return None

    def before_fit(self):
        "Initialize W&B run, log config, watch model, and add model save hook."
        # Ensure wandb is initialized (it might have been initialized outside)
        self.run = wandb.run
        if self.run is None:
            try:
                # Get config from the Learner if available, otherwise log empty
                # Use getattr for safety in case learn.train_config doesn't exist
                cfg_dict = getattr(self.learn, 'train_config', {})
                # If cfg_dict is a dataclass/object, convert to dict if possible
                if not isinstance(cfg_dict, dict) and hasattr(cfg_dict, '__dict__'):
                    cfg_dict = cfg_dict.__dict__
                elif not isinstance(cfg_dict, dict):
                     cfg_dict = {} # Fallback to empty dict

                wandb_mode = "online" if cfg_dict.get('wandb_entity') else "disabled"

                # Ensure wandb.run is None before potentially re-initializing
                # This helps prevent issues in notebooks if cells are re-run
                if wandb.run is not None:
                     logger.warning("Existing wandb run detected. Finishing it before initializing a new one.")
                     wandb.finish()

                self.run = wandb.init(project=cfg_dict.get('wandb_project', 'fastai-project'),
                                      entity=cfg_dict.get('wandb_entity'),
                                      name=cfg_dict.get('wandb_run_name'),
                                      config=cfg_dict, # Log the config object used by main()
                                      reinit=True, # Allow re-initialization
                                      mode=wandb_mode)
                logger.info(f"WandB run initialized: {self.run.name if self.run else 'N/A'} (mode: {wandb.run.mode if self.run else 'N/A'})")
            except Exception as e:
                logger.error(f"Failed to initialize WandB: {e}. Logging disabled.")
                self.run = None # Ensure self.run is None if init fails

        # Reset step/epoch counters
        self._wandb_step = 0
        self._wandb_epoch = 0

        # Watch model if run is active
        if self.run and self.run.mode != "disabled":
            try:
                wandb.watch(self.learn.model, log='all', log_freq=max(100, getattr(self.learn.dls,'train_ds',None) or 1000)//self.learn.dls.bs) # Log gradients/params less frequently
                logger.info("WandB watching model.")
            except Exception as e:
                logger.warning(f"Could not watch model with WandB: {e}")

        # Add model save hook if policy requires it
        if self.log_model_policy != 'false' and self.run and self.run.mode != "disabled":
            save_model_cb = self._find_save_model_cb()
            if save_model_cb:
                 # Check for the hook attribute *before* trying to call it
                 if hasattr(save_model_cb, 'add_save_hooks'):
                     try:
                         save_model_cb.add_save_hooks(self._wandb_log_model)
                         logger.info("Added WandB model logging hook to SaveModelCallback.")
                     except Exception as e:
                         logger.error(f"Unexpected error attaching save hook: {e}", exc_info=True)
                 else:
                    # Log the warning about missing hook method
                    logger.warning(
                      "SaveModelCallback found, but it does not have the 'add_save_hooks' method; skipping W&B model hook."
                    )
            else:
                logger.warning(
                  "SaveModelCallback not found; cannot attach W&B model logging hook."
                )


    def after_batch(self):
        "Log training loss and hyperparameters to W&B after each training batch."
        if self.training and self.run and self.run.mode != "disabled":
            self._wandb_step += 1
            # Correctly log learning rate and other hypers
            hypers = {}
            if hasattr(self.learn.opt, 'hypers'):
                hypers = {f'hp_{i}_{k}': v for i, h in enumerate(self.learn.opt.hypers) for k, v in h.items()}
            elif hasattr(self.learn.opt, 'param_groups'): # Fallback for some optimizers
                hypers = {f'hp_{i}_lr': pg['lr'] for i, pg in enumerate(self.learn.opt.param_groups)}


            # Get smoothed loss from Recorder if available, else raw loss
            smooth_loss = self.learn.smooth_loss.item() if hasattr(self.learn,'smooth_loss') and getattr(self.learn.smooth_loss, 'is_valid', True) else self.learn.loss.item()
            raw_loss = self.learn.loss.item()
            log_data = {
                'epoch_frac': self._wandb_epoch + self.learn.pct_train, # Log fractional epoch
                'train_loss': smooth_loss if not math.isnan(smooth_loss) else 0.0,
                'raw_loss': raw_loss if not math.isnan(raw_loss) else 0.0,
                **hypers
            }
            wandb.log(log_data, step=self._wandb_step)

    def before_epoch(self):
        "Update internal epoch counter."
        self._wandb_epoch = self.learn.epoch # Update epoch number

    # (inside SimpleWandbCallback)
    def after_epoch(self):
        rec = next((cb for cb in self.learn.cbs if isinstance(cb, Recorder)), None)
        if rec is None or not rec.values: return
        names = rec.metric_names
        vals  = rec.values[-1]
        log_dict = {n: float(v) for n,v in zip(names, vals)}
        log_dict['epoch'] = self._wandb_epoch
        wandb.log(log_dict, step=self._wandb_step)
        logger.info(f"Logged to W&B: {log_dict}")

    def _wandb_log_model(self, learn, file, **kwargs):
        "Hook function called by SaveModelCallback to log model artifacts to W&B."
        if self.run is None or self.run.mode == "disabled": return # Don't log if wandb is disabled

        save_model_cb = self._find_save_model_cb()
        if not save_model_cb:
            logger.error("SaveModelCallback not found in _wandb_log_model hook.")
            return

        try:
            # Check if this save corresponds to a new best model
            # Check both `new_best` (from >=2.7) and `is_new_best` (older tracker)
            is_new_best = getattr(save_model_cb, 'new_best', getattr(save_model_cb, 'is_new_best', False))
            should_log = self.log_model_policy == 'all' or (self.log_model_policy == 'best' and is_new_best)

            if should_log:
                # Get metadata from recorder if possible
                recorder_cb = self._find_recorder_cb()
                metadata = {}
                if recorder_cb and hasattr(recorder_cb, 'final_record') and hasattr(recorder_cb, 'metric_names'):
                    # Use final_record for accurate end-of-epoch values
                     metadata = { n:f'{v:.5f}' for n,v in zip(recorder_cb.metric_names, recorder_cb.final_record) if n not in ['epoch', 'time']}
                elif recorder_cb and hasattr(recorder_cb, 'log') and hasattr(recorder_cb, 'metric_names'):
                     # Fallback to log if final_record not available
                     logger.warning("Using recorder.log for artifact metadata (may not be final values).")
                     # Attempt to map log values; indices might be incorrect (see after_epoch)
                     # This part is less reliable.
                     valid_loss_idx = recorder_cb.metric_names.index('valid_loss') if 'valid_loss' in recorder_cb.metric_names else -1
                     if valid_loss_idx != -1 and len(recorder_cb.log) > 1:
                         metadata['valid_loss'] = f'{recorder_cb.log[1]:.5f}' # Assuming log[1] is valid_loss

                aliases = [f'epoch_{learn.epoch}']
                if is_new_best: aliases.append('best')

                fname = getattr(save_model_cb, 'fname', 'model') # fname is usually just the base name like 'best_model'
                artifact_name = f'{self.run.name}_{fname}' if self.run.name else fname
                artifact = wandb.Artifact(name=artifact_name, type='model', metadata=metadata)

                # `file` passed by the hook is the full path to the saved file
                model_path = Path(file)
                if model_path.is_file():
                    artifact.add_file(model_path)
                    self.run.log_artifact(artifact, aliases=aliases)
                    logger.info(f"Logged model artifact '{artifact_name}' to WandB with aliases {aliases}.")
                else:
                    logger.error(f"Model checkpoint file not found at: {model_path}")

        except Exception as e:
            logger.error(f"Failed to log model artifact to WandB: {e}", exc_info=True)


    def after_fit(self):
        "Log best metric value and finish W&B run."
        if self.run and self.run.mode != "disabled":
            save_model_cb = self._find_save_model_cb()
            # Log best metric value if available
            if save_model_cb and hasattr(save_model_cb, 'best') and save_model_cb.best is not None:
                monitor_metric = save_model_cb.monitor
                wandb.summary[f'best_{monitor_metric}'] = save_model_cb.best
                logger.info(f"Logged best {monitor_metric} to WandB summary: {save_model_cb.best:.4f}")

            # Remove hook - check if hook exists before removing
            if self.log_model_policy != 'false' and save_model_cb and hasattr(save_model_cb, 'remove_save_hooks'):
                 try:
                     # Check if the hook is actually present before removing
                     if hasattr(save_model_cb, 'save_hooks') and self._wandb_log_model in save_model_cb.save_hooks:
                         save_model_cb.remove_save_hooks(self._wandb_log_model)
                         logger.info("Removed WandB save hook.")
                     else:
                         logger.info("WandB save hook not found or already removed.")

                 except Exception as e:
                     logger.warning(f"Could not remove WandB save hook: {e}")

            # Check if wandb.run is still active before finishing
            if wandb.run is not None:
                wandb.finish()
                logger.info("WandB run finished.")
            else:
                logger.info("WandB run already finished or was never initialized.")

        # Reset internal state
        self._wandb_step = 0
        self._wandb_epoch = 0
        self.run = None # Clear the run object


In [9]:
try:
    from indic_clip.learner import _retrieval_metric_values
except (ImportError, NameError):
     # Define if not importable (e.g. callback defined in notebook)
     _retrieval_metric_values = {}

def valid_mean_recall(inp=None, targ=None): return tensor(_retrieval_metric_values.get('mean_recall', 0.0))
def valid_i2t_r_at_1(inp=None, targ=None): return tensor(_retrieval_metric_values.get('i2t_r@1', 0.0))
def valid_t2i_r_at_1(inp=None, targ=None): return tensor(_retrieval_metric_values.get('t2i_r@1', 0.0))
# Add R@5, R@10 etc. if needed
def valid_i2t_r_at_5(inp=None, targ=None): return tensor(_retrieval_metric_values.get('i2t_r@5', 0.0))
def valid_t2i_r_at_5(inp=None, targ=None): return tensor(_retrieval_metric_values.get('t2i_r@5', 0.0))
def valid_i2t_r_at_10(inp=None, targ=None): return tensor(_retrieval_metric_values.get('i2t_r@10', 0.0))
def valid_t2i_r_at_10(inp=None, targ=None): return tensor(_retrieval_metric_values.get('t2i_r@10', 0.0))

In [10]:
# Function to encapsulate the training logic
def main(config: TrainConfig):
    """Main function to setup and run the training process using the config object."""
    # ... (Basic Setup, Load Data, Instantiate Model, Loss, Optimizer remains largely the same) ...
    set_seed(config.seed)
    ensure_dir(Path(config.checkpoint_dir))
    logger.info(f"Starting training run with config: {config.__dict__}")

    # --- Load Data ---\n" +
    logger.info(f"Loading data items from: {config.processed_data_path}")
    items_df = get_indic_clip_items(data_path=Path(config.processed_data_path))
    if items_df.empty:
        logger.error("No data items loaded. Exiting.")
        if wandb.run: wandb.finish(exit_code=1)
        return

    logger.info(f"Instantiating DataBlock...")
    tokenizer = IndicBERTTokenizer.load_tokenizer(Path(config.tokenizer_path), max_length=config.max_seq_len)
    indic_clip_dblock = IndicCLIPDataBlock(
        tokenizer_name_or_path=config.text_model_name,
        tokenizer_save_path=Path(config.tokenizer_path),
        max_length=config.max_seq_len,
        img_size=config.img_size,
        valid_pct=config.valid_pct,
        seed=config.seed,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        use_augmentations=config.use_augmentations
    )
    logger.info(f"Creating DataLoaders...")
    dls = indic_clip_dblock.get_dataloaders(items_df)
    logger.info(f"DataLoaders created. Train batches: {len(dls.train)}, Valid batches: {len(dls.valid)}")

    # --- Instantiate Model ---\n" +
    logger.info("Instantiating IndicCLIP model...")
    model = IndicCLIP(
        embed_dim=config.embed_dim,
        vision_model_name=config.vision_model_name,
        vision_pretrained=config.vision_pretrained,
        text_model_name=config.text_model_name,
        text_pretrained=config.text_pretrained,
        tokenizer=tokenizer
    )

    # --- Loss Function ---\n" +
    metrics = [
            valid_mean_recall,
            valid_i2t_r_at_1,
            valid_t2i_r_at_1,
            valid_i2t_r_at_5,
            valid_t2i_r_at_5,
            valid_i2t_r_at_10,
            valid_t2i_r_at_10,
            ]

    # --- Optimizer ---\n" +
    opt_func = partial(OptimWrapper, opt=torch.optim.AdamW,
                       betas=(config.beta1, config.beta2), eps=config.eps)

    # --- Callbacks ---\n" +
    logger.info("Configuring callbacks...")

    # Create standard callbacks first
    save_cb = SaveModelCallback(
            monitor='valid_loss',
            comp=np.less,
            fname=config.save_model_name,
            # every_epoch=(config.save_epoch_frequency > 0 and config.save_epoch_frequency <= config.epochs),
            at_end=True,
            with_opt=True
            )
    early_stop_cb = EarlyStoppingCallback(
            monitor='valid_loss',
            comp=np.less,
            patience=config.early_stopping_patience
            )
    retrieval_cb = RetrievalMetricCallback(k_values=[1, 5, 10])

    callbacks = [
        # InputInspectCallback(),
        # GradientDebugCallback(),
        # Recorder() is added automatically by Learner if metrics are present
        #retrieval_cb,
        # DebugRecorderStateCallback(),
        # save_cb,
        # early_stop_cb,
        # ProgressCallback()
    ]

    # Add optional callbacks
    if config.use_amp:
        logger.info("Using Automatic Mixed Precision (AMP).")
        callbacks.insert(0, MixedPrecision())
    if config.accum_freq > 1:
        logger.info("Using Gradient Accumulation with frequency {config.accum_freq}.")
        callbacks.insert(0, GradientAccumulation(n_acc=config.accum_freq))
    if config.grad_clip is not None:
         logger.info("Using Gradient Clipping with value {config.grad_clip}.")
         callbacks.append(GradientClip(config.grad_clip))

    # Create SimpleWandbCallback and add the hook to SaveModelCallback
    # Do this *before* creating the Learner so hooks are set
    simple_wandb_cb = SimpleWandbCallback(log_model_policy=config.wandb_log_model)
    # Guard against SaveModelCallback not supporting add_save_hooks
    if config.wandb_log_model != 'false' and hasattr(save_cb, 'add_save_hooks'):
        try:
            save_cb.add_save_hooks(simple_wandb_cb._wandb_log_model)
            logger.info("Added WandB save hook to SaveModelCallback.")
        except Exception as e:
            logger.error(f"Unexpected error attaching save hook: {e}", exc_info=True)
    else:
        logger.warning("SaveModelCallback has no add_save_hooks—skipping W&B model hook.")

    simple_wandb_cb.order = max(SaveModelCallback.order, DebugRecorderStateCallback.order) + 1
    callbacks.append(simple_wandb_cb)

    # --- Create Learner ---\n" +
    logger.info("Creating fastai Learner...")
    loss_func = ContrastiveLoss()
    logger.info(f"Using loss function: {type(loss_func)}")
    learn = Learner(dls, model, loss_func=loss_func, opt_func=opt_func,
                    wd=config.wd, cbs=callbacks, metrics=None)

    # logger.info(f"Learner Summary: \n {learn.summary()}")

    # Handle resuming from checkpoint (load state into the created learner)
    if config.resume_from:
        resume_path = Path(config.resume_from)
        if resume_path.is_file():
            try:
                logger.info(f"Resuming training from checkpoint: {resume_path}")
                learn.load(resume_path.stem, with_opt=True, device=dls.device)
            except Exception as e:
                logger.error(f"Failed to load checkpoint {resume_path}: {e}. Starting from scratch.")
        else:
            logger.warning(f"Resume checkpoint not found at {resume_path}. Starting from scratch.")

    # --- Start Training ---\n" +
    logger.info(f"Starting training for {config.epochs} epochs...")
    if len(dls.train) == 0:
         logger.error("Training dataloader is empty. Cannot calculate steps.")
         if wandb.run: wandb.finish(exit_code=1)
         return

    total_steps = len(dls.train) * config.epochs
    actual_epochs = config.epochs

    if config.max_steps:
        logger.warning(f"Limiting training to a maximum of {config.max_steps} steps." )
        epochs_to_run = math.ceil(config.max_steps / len(dls.train))
        actual_epochs = min(config.epochs, int(epochs_to_run))
        logger.info(f"Adjusted epochs to {actual_epochs} based on max_steps.")
        # Add TrainStepsCallback (if not already present)
        # Need to recreate learner or add callback dynamically *before* fit
        has_steps_cb = any(isinstance(cb, TrainStepsCallback) for cb in learn.cbs)
        if not has_steps_cb:
            learn.add_cb(TrainStepsCallback(config.max_steps))
            logger.info("Added TrainStepsCallback.")

    print(f"loss function: {learn.loss_func}")
    # Run training
    learn.fit_one_cycle(
        n_epoch=actual_epochs,
        lr_max=config.lr,
        pct_start=min(0.3, config.warmup_steps / total_steps) if total_steps > 0 else 0.1
    )

    logger.info("Training finished.")
    # Ensure WandB run is finished cleanly
    if wandb.run is not None:
         wandb.finish()
         logger.info("WandB run finished.")

## Run Training (Interactive)

In [None]:
# --- Interactive Execution Block (for Notebooks) ---

# Check if running interactively (not as a script imported by another module)
if __name__ == '__main__' and '__file__' not in globals():
    print("Running training interactively from notebook...")

    # Instantiate the config object defined in the cell above
    # --- >>>> IMPORTANT: Modify the TrainConfig class definition cell <<<< ---
    # --- >>>>        directly to change hyperparameters for your run <<<< ---
    config = TrainConfig()

    # --- Example Overrides for a quick test ---
    config.epochs = 1
    config.batch_size = 32 # Reduced batch size for faster iteration/memory
    # config.max_steps = 10 # Uncomment to run only a few steps
    # config.wandb_entity = "numb3r33" # <<< SET YOUR WANDB ENTITY HERE >>>
    config.wandb_entity = None # <<< SET YOUR WANDB ENTITY HERE >>>
    
    config.vision_model_name = 'resnet18' # Use a smaller vision model for testing
    config.embed_dim = 512 # Adjust embed_dim to match ResNet18 output
    config.num_workers = 16 # Often safer for interactive debugging
    config.valid_pct = 0.25 # Use a bit more validation data for testing
    config.use_amp = False # Use Automatic Mixed Precision

    # ----------------------------------------

    if not config.wandb_entity:
        print("*** Warning: WandB Entity is not set in config or environment. ***")
        print("*** WandB logging will be disabled. Set config.wandb_entity or WANDB_ENTITY env var. ***")

    print(f"Starting interactive run with config:")
    # Pretty print config
    for key, value in config.__dict__.items():
        print(f"  {key}: {value}")
    print("-" * 30)

    try:
       main(config=config) # Pass the config object to main
    except Exception as e:
       logger.error(f"Interactive run failed: {e}", exc_info=True)
       # Ensure wandb is finished even on error
       if wandb.run is not None and wandb.run.mode != "disabled":
           wandb.finish(exit_code=1)
    
    print("\nInteractive execution cell finished.")

2025-04-22 05:56:31 - __main__ - INFO - Starting training run with config: {'epochs': 1, 'batch_size': 32, 'wandb_entity': None, 'vision_model_name': 'resnet18', 'embed_dim': 512, 'num_workers': 16, 'valid_pct': 0.25, 'use_amp': False}
2025-04-22 05:56:31 - __main__ - INFO - Loading data items from: /workspace/indic-clip/data/processed/filtered_data.jsonl
2025-04-22 05:56:31 - __main__ - INFO - Instantiating DataBlock...


Running training interactively from notebook...
*** WandB logging will be disabled. Set config.wandb_entity or WANDB_ENTITY env var. ***
Starting interactive run with config:
  epochs: 1
  batch_size: 32
  wandb_entity: None
  vision_model_name: resnet18
  embed_dim: 512
  num_workers: 16
  valid_pct: 0.25
  use_amp: False
------------------------------
Loaded 8006 items from /workspace/indic-clip/data/processed/filtered_data.jsonl


2025-04-22 05:56:31 - indic_clip.data.tokenization - INFO - Successfully loaded tokenizer: /workspace/indic-clip/models/tokenizer
2025-04-22 05:56:32 - indic_clip.data.tokenization - INFO - Custom special tokens already exist or none were specified.
2025-04-22 05:56:32 - indic_clip.data.tokenization - INFO - Tokenizer state loaded successfully from /workspace/indic-clip/models/tokenizer
2025-04-22 05:56:32 - indic_clip.data.tokenization - INFO - Successfully loaded tokenizer: /workspace/indic-clip/models/tokenizer
2025-04-22 05:56:32 - indic_clip.data.tokenization - INFO - Custom special tokens already exist or none were specified.
2025-04-22 05:56:32 - indic_clip.data.tokenization - INFO - Tokenizer state loaded successfully from /workspace/indic-clip/models/tokenizer
2025-04-22 05:56:32 - __main__ - INFO - Creating DataLoaders...


Creating DataLoaders with bs=32, num_workers=16


2025-04-22 05:56:33 - __main__ - INFO - DataLoaders created. Train batches: 187, Valid batches: 63
2025-04-22 05:56:33 - __main__ - INFO - Instantiating IndicCLIP model...
2025-04-22 05:56:33 - timm.models._builder - INFO - Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)
2025-04-22 05:56:33 - timm.models._hub - INFO - [timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-04-22 05:56:33 - indic_clip.model.vision - INFO - Loaded timm model: resnet18 with pretrained=True
2025-04-22 05:56:33 - indic_clip.model.vision - INFO - Backbone feature dimension: 512
2025-04-22 05:56:34 - indic_clip.model.text - INFO - Loading text model: ai4bharat/indic-bert with pretrained=True
2025-04-22 05:56:34 - indic_clip.model.text - INFO - Model hidden dimension: 768
2025-04-22 05:56:35 - indic_clip.model.text - INFO - Model embedding size resized to 200002
2025-04-22 05:56:35 - indic_clip.model.c

loss function: ContrastiveLoss()


epoch,train_loss,valid_loss,time


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [None]:
#| hide
# Export the notebook using nbdev (Optional - this notebook is now interactive-only)
if __name__ == '__main__' and '__file__' not in globals(): # Only run in notebook context
    pass # No export needed for interactive-only notebook
    # print("Skipping export for interactive-only notebook.")
    # try:
    #     import nbdev.export
    #     print("Please run 'nbdev_export' in your terminal from the project root directory to export other notebooks.")
    # except ImportError:
    #     print("nbdev not found. Run 'pip install nbdev' to use nbdev features.")
    # except Exception as e:
    #     print(f"Error during nbdev check: {e}")