# 🎬 YOWO Multi-Task Training on Google Colab

**Model**: `yowo_v2_x3d_m_yolo11m_multitask`  
**Dataset**: Charades + Action Genome (288K keyframes, 219 classes)

### Optimized Batch Sizes (with AMP)

| GPU | VRAM | Batch | Accum | Effective | Est. Time/Epoch |
|-----|------|-------|-------|-----------|-----------------|
| T4 | 16GB | 8 | 4 | 32 | ~4 hours |
| L4 | 24GB | 12 | 4 | 48 | ~2.5 hours |
| V100 | 16GB | 10 | 4 | 40 | ~2 hours |
| A100 | 40GB | 32 | 2 | 64 | ~50 min |
| A100 | 80GB | 64 | 2 | 128 | ~30 min |
| H100 | 80GB | 80 | 2 | 160 | ~20 min |

**Features**: AMP (FP16), Multi-head (Objects + Actions + Relationships)


In [None]:
# Cell 1: Check GPU & Auto-Configure Batch Size
import torch
print("=" * 70)
print("üîç GPU Detection & Configuration")
print("=" * 70)

if not torch.cuda.is_available():
    raise RuntimeError("❌ No GPU! Go to Runtime > Change runtime type > GPU")

gpu_name = torch.cuda.get_device_name(0)
gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

print(f"✅ GPU: {gpu_name}")
print(f"✅ VRAM: {gpu_memory_gb:.1f} GB")

# =============================================================================
# OPTIMIZED BATCH SIZES FOR YOWO V2 + X3D-M + YOLO11m WITH AMP
# Starting points based on empirical testing of video action detection models;
# adjust if you hit OOM.
# AMP (FP16) cuts activation memory (roughly 40%), allowing larger batches.
# =============================================================================
if "A100" in gpu_name.upper():
    if gpu_memory_gb > 45:  # A100 80GB
        BATCH_SIZE, ACCUMULATE = 64, 2   # Effective: 128 (can try 80 if stable)
    else:  # A100 40GB
        BATCH_SIZE, ACCUMULATE = 32, 2   # Effective: 64 (can try 40-48)
elif "H100" in gpu_name:
    BATCH_SIZE, ACCUMULATE = 80, 2       # Effective: 160 (can try 96)
elif "L4" in gpu_name:
    BATCH_SIZE, ACCUMULATE = 12, 4       # Effective: 48
elif "T4" in gpu_name:
    BATCH_SIZE, ACCUMULATE = 8, 4        # Effective: 32 (can try 10)
elif "V100" in gpu_name:
    BATCH_SIZE, ACCUMULATE = 10, 4       # Effective: 40
elif "P100" in gpu_name:
    BATCH_SIZE, ACCUMULATE = 6, 4        # Effective: 24
else:
    # Unknown GPU - use conservative settings based on memory
    if gpu_memory_gb >= 40:
        BATCH_SIZE, ACCUMULATE = 32, 2
    elif gpu_memory_gb >= 20:
        BATCH_SIZE, ACCUMULATE = 12, 4
    else:
        BATCH_SIZE, ACCUMULATE = 8, 4

effective = BATCH_SIZE * ACCUMULATE
print(f"\nüì¶ Optimized for {gpu_name}:")
print(f"   batch_size = {BATCH_SIZE}")
print(f"   accumulate = {ACCUMULATE}")
print(f"   effective_batch = {effective}")
print(f"\nüí° If OOM: reduce BATCH_SIZE by 2, increase ACCUMULATE proportionally")
print("=" * 70)
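
Cell 1 matches GPU names with plain substring tests, so a name such as `NVIDIA L40S` also hits the `"L4"` branch. The sketch below is a hypothetical table-driven variant (`pick_batch_config` and `GPU_RULES` are illustrative names, not part of the repo) that uses regex word boundaries; the batch numbers are copied from Cell 1.

```python
import re

# (regex, batch_size, accumulate) copied from Cell 1. Word boundaries (\b)
# keep "L4" from matching longer names such as "L40" or "L40S".
GPU_RULES = [
    (r"\bH100\b", 80, 2),
    (r"\bL4\b", 12, 4),
    (r"\bT4\b", 8, 4),
    (r"\bV100\b", 10, 4),
    (r"\bP100\b", 6, 4),
]


def pick_batch_config(gpu_name: str, gpu_memory_gb: float) -> tuple:
    """Return (batch_size, accumulate) for a GPU name and its memory in GB."""
    name = gpu_name.upper()
    # A100 comes in 40 GB and 80 GB variants, so split on memory like Cell 1
    if re.search(r"\bA100\b", name):
        return (64, 2) if gpu_memory_gb > 45 else (32, 2)
    for pattern, batch, accum in GPU_RULES:
        if re.search(pattern, name):
            return batch, accum
    # Unknown GPU: conservative settings based on total memory
    if gpu_memory_gb >= 40:
        return 32, 2
    if gpu_memory_gb >= 20:
        return 12, 4
    return 8, 4
```

Usage: `BATCH_SIZE, ACCUMULATE = pick_batch_config(gpu_name, gpu_memory_gb)` after the detection lines above. Unlisted GPUs, including the L40S, fall through to the memory-based defaults.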


In [None]:
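# Cell 3: Clone Repository & Install Dependencies
# WARNING: `rm -rf yowo` deletes any existing /content/yowo, including downloaded
# annotations and trained weights. Copy weights to Drive (Cell 9) before re-running.
%cd /content
!rm -rf yowo
!git clone https://github.com/michelsedgh/yowo.git
%cd yowo
!pip install -q torch torchvision opencv-python thop scipy matplotlib numpy imageio pytorchvideo ultralytics tensorboard
print("✅ Repository cloned and dependencies installed!")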

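After the `pip install`, a quick import check catches a missing package before the long download steps. A minimal sketch using `importlib.util.find_spec`; the module list mirrors the `pip install` line, with `opencv-python` imported as `cv2` (`missing_modules` is a hypothetical helper name).

```python
import importlib.util

# Import names for the packages installed above ("opencv-python" imports as "cv2")
REQUIRED_MODULES = [
    "torch", "torchvision", "cv2", "thop", "scipy", "matplotlib",
    "numpy", "imageio", "pytorchvideo", "ultralytics", "tensorboard",
]


def missing_modules(names):
    """Return the module names that cannot be found on the current sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
print("✅ All packages importable" if not missing else f"❌ Missing modules: {missing}")
```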

In [None]:
# Cell 4: Download Action Genome Annotations
# (Frames are downloaded and mounted in the next cell.)
import os, requests, zipfile

DATA_ROOT = "/content/yowo/data/ActionGenome"
ANN_DIR = os.path.join(DATA_ROOT, "annotations")

os.makedirs(ANN_DIR, exist_ok=True)

# =============================================================================
# STEP 1: Download Action Genome annotations (PKL files NOT in git repo!)
# =============================================================================
print("=" * 60)
print("üì• STEP 1: Downloading Action Genome Annotations")
print("=" * 60)

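def download_file(url, filepath):
    """Stream `url` to `filepath`; skip the download if the file already exists."""
    if os.path.exists(filepath):
        size = os.path.getsize(filepath) / 1e6
        print(f"   ✅ {os.path.basename(filepath)} exists ({size:.1f} MB)")
        return True
    print(f"   Downloading {os.path.basename(filepath)}...")
    # Write to a temp file first so a failed download is never mistaken for a complete one
    tmp_path = filepath + ".part"
    try:
        with requests.get(url, stream=True, timeout=120) as response:
            response.raise_for_status()  # report HTTP errors instead of failing silently
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, filepath)
        size = os.path.getsize(filepath) / 1e6
        print(f"   ✅ Downloaded ({size:.1f} MB)")
        return True
    except Exception as e:
        print(f"   ❌ Failed: {e}")
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        return False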

# Action Genome annotations from STAR Benchmark S3
ag_files = {
    'object_bbox_and_relationship.pkl': 'https://star-benchmark.s3.us-east.cloud-object-storage.appdomain.cloud/Annotations/object_bbox_and_relationship.pkl',
    'person_bbox.pkl': 'https://star-benchmark.s3.us-east.cloud-object-storage.appdomain.cloud/Annotations/person_bbox.pkl',
    'classes.zip': 'https://star-benchmark.s3.us-east.cloud-object-storage.appdomain.cloud/Annotations/classes.zip'
}

for filename, url in ag_files.items():
    download_file(url, os.path.join(ANN_DIR, filename))

# Extract classes.zip if needed
classes_zip = os.path.join(ANN_DIR, 'classes.zip')
if os.path.exists(classes_zip) and not os.path.exists(os.path.join(ANN_DIR, 'object_classes.txt')):
    print("   Extracting classes.zip...")
    with zipfile.ZipFile(classes_zip, 'r') as z:
        z.extractall(ANN_DIR)
    # Move files from classes/ subdirectory if needed
    classes_subdir = os.path.join(ANN_DIR, 'classes')
    if os.path.exists(classes_subdir):
        import shutil
        for f in os.listdir(classes_subdir):
            shutil.move(os.path.join(classes_subdir, f), os.path.join(ANN_DIR, f))
        shutil.rmtree(classes_subdir)
    print("   ✅ Extracted class files")

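The `classes.zip` handling above can be packaged as one function that also rejects a truncated download before extracting. A minimal sketch, assuming (as the code above does) that the archive may wrap its files in a `classes/` subfolder; `extract_flat` is a hypothetical name.

```python
import os
import shutil
import zipfile


def extract_flat(zip_path, dest_dir, subdir="classes"):
    """Extract zip_path into dest_dir, lifting files out of dest_dir/subdir.

    Returns the sorted contents of dest_dir after extraction.
    """
    # A partial download is usually not a valid zip; fail loudly instead
    if not zipfile.is_zipfile(zip_path):
        raise ValueError(f"{zip_path} is not a valid zip (partial download?)")
    with zipfile.ZipFile(zip_path) as z:
        z.extractall(dest_dir)
    nested = os.path.join(dest_dir, subdir)
    if os.path.isdir(nested):
        for name in os.listdir(nested):
            # os.replace overwrites an existing file of the same name
            # (assumes the subfolder holds only files)
            os.replace(os.path.join(nested, name), os.path.join(dest_dir, name))
        shutil.rmtree(nested)
    return sorted(os.listdir(dest_dir))
```

Usage: `extract_flat(classes_zip, ANN_DIR)`. Because `classes.zip` itself sits in `ANN_DIR`, it also appears in the returned listing.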


In [None]:
# Cell 4b: Download frames.tar + its ratarmount index from Drive, then mount it
import os
import subprocess
import google.auth
from google.colab import auth
from google.auth.transport.requests import Request

# ==============================================================================
# CONFIGURATION
# ==============================================================================
# 1. The Main Archive
TAR_FILE_ID = "1GuRdUMP5qrqyYN0gg8C2B6tLwJeigyFd"  
LOCAL_TAR = "/content/frames.tar"

# 2. The pre-built ratarmount index (saves re-indexing the whole tar)
INDEX_FILE_ID = "1ecTAlWCWWSfSavneBwlALjhocl3LKXoa"
LOCAL_INDEX = "/content/frames.tar.index.sqlite"

# 3. Paths
# We mount the raw tar here first
TEMP_MOUNT_POINT = "/content/raw_mount" 
# We want the data to appear here eventually
FINAL_TARGET_DIR = "/content/yowo/data/ActionGenome/frames"
# ==============================================================================

def install_tools():
    print("üõ†Ô∏è Installing aria2 and ratarmount...")
    subprocess.run(["apt-get", "install", "-y", "-qq", "aria2"], check=True)
    subprocess.run(["pip", "install", "-q", "ratarmount"], check=True)

def get_token():
    print("üîë Authenticating...")
    auth.authenticate_user()
    creds, _ = google.auth.default()
    creds.refresh(Request())
    return creds.token

def download_file(token, file_id, output_path):
    # aria2c leaves a "<file>.aria2" control file next to an unfinished download
    partial = os.path.exists(output_path + ".aria2")
    if os.path.exists(output_path) and not partial:
        print(f"✅ Found existing file: {output_path}")
        return

    print(f"⬇️ Downloading {os.path.basename(output_path)}...")
    url = f"https://www.googleapis.com/drive/v3/files/{file_id}?alt=media"

    cmd = [
        "aria2c", "-c",  # -c: resume a partial download instead of restarting
        "-x", "16", "-s", "16", "-j", "16",
        "--file-allocation=none", "--summary-interval=10",
        "--header", f"Authorization: Bearer {token}",
        "-o", os.path.basename(output_path),
        "-d", os.path.dirname(output_path),
        url
    ]

    result = subprocess.run(cmd)
    if result.returncode != 0:
        raise Exception(f"Failed to download {output_path}")

def mount_and_link():
    print(f"\nüîó Mounting archive to temp location: {TEMP_MOUNT_POINT}")
    
    # 1. Cleanup
    subprocess.run(["fusermount", "-u", TEMP_MOUNT_POINT], stderr=subprocess.DEVNULL)
    if os.path.islink(FINAL_TARGET_DIR):
        os.unlink(FINAL_TARGET_DIR)
    elif os.path.exists(FINAL_TARGET_DIR):
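        # Remove an empty directory so the symlink can take its place
        try:
            os.rmdir(FINAL_TARGET_DIR)
        except OSError:
            raise Exception(f"{FINAL_TARGET_DIR} exists and is not empty; move it aside first")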

    os.makedirs(TEMP_MOUNT_POINT, exist_ok=True)
    
    # 2. Mount with ratarmount, passing the downloaded index explicitly
    #    (a list argv avoids shell-quoting issues with the paths)
    cmd = ["ratarmount", "-P", "4", "--index-file", LOCAL_INDEX, LOCAL_TAR, TEMP_MOUNT_POINT]
    if subprocess.run(cmd).returncode != 0:
        raise Exception("Ratarmount failed!")

    # 3. Locate the data inside the mount and link it.
    # The archive stores the frames under data/ActionGenome/frames:
    nested_path = os.path.join(TEMP_MOUNT_POINT, "data/ActionGenome/frames")

    # If that path is missing, stop and list the mount root to help debugging
    if not os.path.exists(nested_path):
        raise Exception(
            f"⚠️ Could not find expected path: {nested_path}\n"
            f"📂 Contents of root mount: {os.listdir(TEMP_MOUNT_POINT)}"
        )
    
    # 4. Create the final destination link
    # Ensure parent dir exists
    parent_dir = os.path.dirname(FINAL_TARGET_DIR)
    os.makedirs(parent_dir, exist_ok=True)
    
    print(f"üîó Linking '{nested_path}' --> '{FINAL_TARGET_DIR}'")
    os.symlink(nested_path, FINAL_TARGET_DIR)
    
    # 5. Verify (list the directory once)
    count = len(os.listdir(FINAL_TARGET_DIR)) if os.path.isdir(FINAL_TARGET_DIR) else 0
    if count > 0:
        print(f"🎉 SUCCESS! {count} items visible at {FINAL_TARGET_DIR}")
    else:
        print("❌ Something went wrong. The target folder is empty.")

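# --- EXECUTION ---
try:
    # Unmount Google Drive if it is mounted; the archive is fetched through
    # the Drive API below. Remount Drive before saving weights.
    try:
        from google.colab import drive
        drive.flush_and_unmount()
    except Exception:
        pass

    install_tools()
    token = get_token()

    # Download Tar AND Index
    download_file(token, TAR_FILE_ID, LOCAL_TAR)
    download_file(token, INDEX_FILE_ID, LOCAL_INDEX)

    mount_and_link()

except Exception as e:
    print(f"\n❌ CRITICAL ERROR: {e}")
    raise  # stop here so later cells don't run against a missing dataset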

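The mount cell hard-codes `data/ActionGenome/frames` inside the archive. If the tar layout differs, a bounded breadth-first search can locate the `frames` folder instead. A sketch under that assumption (`find_dir` is a hypothetical helper); the depth limit keeps it from descending into the thousands of per-video folders.

```python
import os
from collections import deque


def find_dir(root, name="frames", max_depth=4):
    """Breadth-first search for a directory called `name` under `root`.

    Returns the shallowest match, or None. Only matches at most
    `max_depth` levels below `root` are found.
    """
    queue = deque([(root, 0)])
    while queue:
        path, depth = queue.popleft()
        try:
            entries = sorted(os.scandir(path), key=lambda e: e.name)
        except OSError:
            continue  # unreadable directory: skip it
        for entry in entries:
            if not entry.is_dir():
                continue
            if entry.name == name:
                return entry.path
            if depth + 1 < max_depth:
                queue.append((entry.path, depth + 1))
    return None
```

Usage: `nested_path = find_dir(TEMP_MOUNT_POINT) or nested_path` before the existence check in `mount_and_link()`.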
In [None]:
# Cell 5: Verify Dataset Structure
import os, pickle

ANN_DIR = "/content/yowo/data/ActionGenome/annotations"
FRAMES_DIR = "/content/yowo/data/ActionGenome/frames"

print("=" * 60)
print("üîç Dataset Verification")
print("=" * 60)

# Check required files
required_files = {
    'person_bbox.pkl': 'Person bounding boxes + keyframes',
    'object_bbox_and_relationship.pkl': 'Objects + relationships',
    'Charades_v1_train.csv': 'Training action labels',
    'Charades_v1_test.csv': 'Test action labels',
    'Charades_v1_classes.txt': '157 action classes',
    'object_classes.txt': '36 object classes',
    'relationship_classes.txt': '26 relationship classes',
    'video_fps.json': 'FPS for each video'
}

print("\nüìã Required Annotation Files:")
all_ok = True
for f, desc in required_files.items():
    path = os.path.join(ANN_DIR, f)
    if os.path.exists(path):
        size = os.path.getsize(path) / 1e6
        print(f"   ✅ {f} ({size:.1f} MB) - {desc}")
    else:
        print(f"   ❌ {f} - MISSING! ({desc})")
        all_ok = False

# Check frames
print(f"\nüìÇ Frames Directory:")
if os.path.exists(FRAMES_DIR):
    num_videos = len(os.listdir(FRAMES_DIR))
    print(f"   ‚úÖ {num_videos} video directories")
    # Sample a video
    sample_vid = os.listdir(FRAMES_DIR)[0]
    sample_frames = len(os.listdir(os.path.join(FRAMES_DIR, sample_vid)))
    print(f"   üìÅ Sample: {sample_vid} has {sample_frames} frames")
else:
    print("   ‚ùå Frames directory missing!")
    all_ok = False

# Verify PKL files are valid
print(f"\nüî¨ Validating PKL Files:")
try:
    with open(os.path.join(ANN_DIR, 'person_bbox.pkl'), 'rb') as f:
        person_data = pickle.load(f)
    print(f"   ‚úÖ person_bbox.pkl: {len(person_data)} keyframes")
    
    with open(os.path.join(ANN_DIR, 'object_bbox_and_relationship.pkl'), 'rb') as f:
        obj_data = pickle.load(f)
    print(f"   ‚úÖ object_bbox_and_relationship.pkl: {len(obj_data)} entries")
except Exception as e:
    print(f"   ‚ùå Error reading PKL files: {e}")
    all_ok = False

if all_ok:
    print("\n" + "=" * 60)
    print("✅ DATASET READY FOR TRAINING!")
    print("=" * 60)
else:
    print("\n" + "=" * 60)
    print("⚠️ DATASET INCOMPLETE - Check errors above")
    print("=" * 60)

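Cell 5 only samples one video folder. To confirm the mounted frames actually cover the annotated keyframes, the keys of `person_bbox.pkl` can be checked against `FRAMES_DIR`. This sketch assumes the Action Genome convention of keys shaped like `VIDEO.mp4/FRAME.png` and a frame layout of `frames/VIDEO.mp4/FRAME.png`; `group_keyframes` and `missing_keyframes` are hypothetical helpers.

```python
import os
from collections import defaultdict


def group_keyframes(keys):
    """Group keys shaped like 'VIDEO.mp4/FRAME.png' into {video: [frames]}."""
    by_video = defaultdict(list)
    for key in keys:
        video, frame = key.split("/", 1)
        by_video[video].append(frame)
    return dict(by_video)


def missing_keyframes(keys, frames_dir, limit=None):
    """Return the keys whose image file does not exist under frames_dir.

    `limit` optionally checks only the first `limit` keys, since each check
    is a filesystem stat and those can add up over ~288K files on a FUSE mount.
    """
    keys = list(keys)
    if limit is not None:
        keys = keys[:limit]
    return [k for k in keys if not os.path.exists(os.path.join(frames_dir, k))]
```

Usage: `missing_keyframes(person_data.keys(), FRAMES_DIR, limit=1000)` checks a sample; an empty list means every sampled keyframe is on disk.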

## 🚀 Ready to Train!

**Model Architecture: `yowo_v2_x3d_m_yolo11m_multitask`**

| Component | Description |
|-----------|-------------|
| 2D Backbone | YOLO11m (pretrained on COCO) |
| 3D Backbone | X3D-M (pretrained on Kinetics-400) |
| Object Head | 36 classes (person + 35 objects) |
| Action Head | 157 Charades action classes |
| Relation Head | 26 relationship classes |
| Interaction Head | Binary (is object interacted with?) |

**Dataset: Charades + Action Genome**
- 288,782 annotated keyframes
- 9,601 videos
- Multi-task: Objects + Actions + Relationships

**Note:** Model checkpoints are saved after each epoch to `/content/yowo/weights/charades_ag/`.

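The head sizes in the table (36 + 157 + 26 = 219 classes) can be checked against the class files listed in Cell 5. A sketch that assumes one class per non-blank line in each file (`check_class_files` is a hypothetical helper):

```python
import os

# Class counts stated in the architecture table (36 + 157 + 26 = 219)
EXPECTED_COUNTS = {
    "object_classes.txt": 36,
    "Charades_v1_classes.txt": 157,
    "relationship_classes.txt": 26,
}


def count_lines(path):
    """Number of non-blank lines in a text file."""
    with open(path, encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


def check_class_files(ann_dir, expected=EXPECTED_COUNTS):
    """Return {filename: (found, expected)} for every file whose count differs.

    A missing file is reported with found == 0.
    """
    mismatches = {}
    for name, want in expected.items():
        path = os.path.join(ann_dir, name)
        got = count_lines(path) if os.path.exists(path) else 0
        if got != want:
            mismatches[name] = (got, want)
    return mismatches
```

Usage: `check_class_files(ANN_DIR)` returns an empty dict when every file matches the table.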

In [None]:
# Cell 8: üöÄ TRAIN! (Fresh Start - Optimized Configuration)
# AMP (Automatic Mixed Precision) enabled for ~1.5-2x faster training!

import os
os.chdir('/content/yowo')

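# Training configuration - OPTIMIZED
# NOTE: BATCH_SIZE/ACCUMULATE here override the values auto-selected in Cell 1.
# 160 per step is twice the table's H100 setting (80); if you hit OOM, fall back
# to Cell 1's values.
BATCH_SIZE = 160
ACCUMULATE = 2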
MAX_EPOCHS = 13
LEARNING_RATE = 0.00035
LR_DECAY_EPOCHS = "7 9 11 12"
LEN_CLIP = 16
NUM_WORKERS = 4

# Build command
cmd = f"""python train.py \
    -d charades_ag \
    -v yowo_v2_x3d_m_yolo11m_multitask \
    --cuda \
    --amp \
    -bs {BATCH_SIZE} \
    -accu {ACCUMULATE} \
    --max_epoch {MAX_EPOCHS} \
    --lr_epoch {LR_DECAY_EPOCHS} \
    --root /content/yowo/data \
    -K {LEN_CLIP} \
    -lr {LEARNING_RATE} \
    --num_workers {NUM_WORKERS} \
    --save_folder /content/yowo/weights"""

print("=" * 70)
print("🚀 FRESH START - OPTIMIZED TRAINING")
print("=" * 70)
print(f"📦 Batch size: {BATCH_SIZE} × {ACCUMULATE} = {BATCH_SIZE * ACCUMULATE} effective")
print(f"📊 Epochs: {MAX_EPOCHS}")
print(f"📈 Learning rate: {LEARNING_RATE}")
print(f"📉 LR decay at epochs: {LR_DECAY_EPOCHS}")
print()
print("   LR Schedule (MultiStepLR, halved at each milestone, stepped after each epoch):")
print("   Epoch 1-7:   lr = 0.00035")
print("   Epoch 8-9:   lr = 0.000175")
print("   Epoch 10-11: lr = 0.0000875")
print("   Epoch 12:    lr = 0.00004375")
print("   Epoch 13:    lr = 0.00002188")
print()
print(f"🎬 Clip length: {LEN_CLIP} frames")
print("⚡ AMP: Enabled")
print(f"\n📋 Full command:\n{cmd}\n")
print("=" * 70 + "\n")

!{cmd}

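The printed schedules follow from how `MultiStepLR` behaves when `lr_scheduler.step()` is called once at the end of each epoch: during 1-indexed epoch e the scheduler has stepped e - 1 times, so the LR is `base_lr * gamma ** (number of milestones <= e - 1)`. A small sketch that reproduces the table, assuming `gamma = 0.5` (the halving shown in the printed schedules) and ignoring any warmup in `train.py`; `multistep_lr` and `lr_table` are hypothetical names.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """LR during 1-indexed `epoch` if MultiStepLR.step() runs once per epoch.

    Before epoch e the scheduler has stepped e - 1 times, and MultiStepLR has
    applied one factor of `gamma` for every milestone <= e - 1.
    """
    num_decays = bisect_right(sorted(milestones), epoch - 1)
    return base_lr * gamma ** num_decays


def lr_table(base_lr, milestones, gamma, max_epoch):
    """Map each 1-indexed epoch to its learning rate."""
    return {e: multistep_lr(base_lr, milestones, gamma, e) for e in range(1, max_epoch + 1)}
```

Usage: `lr_table(LEARNING_RATE, [int(m) for m in LR_DECAY_EPOCHS.split()], 0.5, MAX_EPOCHS)` gives the schedule for whatever values Cell 8 uses. The same formula motivates the resume fix: when training resumes at epoch 6, a freshly built scheduler must be stepped 5 times so the decays still land at epochs 8, 10, 12 and 13.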
In [None]:
# =============================================================================
# RESUME TRAINING FROM THE EPOCH-5 CHECKPOINT (CONTINUES AT EPOCH 6) - PRODUCTION SCHEDULE
# =============================================================================
# This cell:
# 1. Fixes PyTorch 2.6 compatibility issue
# 2. Fixes missing optimizer state in checkpoint
# 3. Fixes LR scheduler to advance to correct epoch
# 4. Resumes from epoch 5 checkpoint with proper LR schedule
# =============================================================================

import os
os.chdir('/content/yowo')

# -----------------------------------------------------------------------------
# FIX 1: PyTorch 2.6 compatibility (add weights_only=False)
# FIX 2: Handle missing optimizer state in checkpoint
# -----------------------------------------------------------------------------
print("üîß Applying fixes...")

# Create fixed optimizer.py
optimizer_fix = '''import torch
from torch import optim


def build_optimizer(cfg, model, base_lr=0.0, resume=None):
    print('==============================')
    print('Optimizer: {}'.format(cfg['optimizer']))
    print('--momentum: {}'.format(cfg['momentum']))
    print('--weight_decay: {}'.format(cfg['weight_decay']))

    if cfg['optimizer'] == 'sgd':
        optimizer = optim.SGD(
            model.parameters(), 
            lr=base_lr,
            momentum=cfg['momentum'],
            weight_decay=cfg['weight_decay'])

    elif cfg['optimizer'] == 'adam':
        optimizer = optim.Adam(
            model.parameters(), 
            lr=base_lr,
            weight_decay=cfg['weight_decay'])
                                
    elif cfg['optimizer'] == 'adamw':
        optimizer = optim.AdamW(
            model.parameters(), 
            lr=base_lr,
            weight_decay=cfg['weight_decay'])
          
    start_epoch = 0
    if resume is not None:
        print('keep training: ', resume)
        checkpoint = torch.load(resume, weights_only=False)
        # Load optimizer state if available
        if "optimizer" in checkpoint:
            checkpoint_state_dict = checkpoint.pop("optimizer")
            optimizer.load_state_dict(checkpoint_state_dict)
            print('Loaded optimizer state from checkpoint')
        else:
            print('No optimizer state in checkpoint, using fresh optimizer')
        # Load epoch
        if "epoch" in checkpoint:
            start_epoch = checkpoint.pop("epoch") + 1
            print(f'Resuming from epoch {start_epoch}')
                        
    return optimizer, start_epoch
'''

with open('/content/yowo/utils/solver/optimizer.py', 'w') as f:
    f.write(optimizer_fix)
print("✅ Fixed optimizer.py (handles missing optimizer state)")

# Fix build_multitask.py (only report success if the call was actually found)
bm_path = '/content/yowo/models/yowo/build_multitask.py'
with open(bm_path, 'r') as f:
    content = f.read()
old_load = "torch.load(resume, map_location='cpu')"
if old_load in content:
    content = content.replace(old_load, "torch.load(resume, map_location='cpu', weights_only=False)")
    with open(bm_path, 'w') as f:
        f.write(content)
    print("✅ Fixed build_multitask.py (PyTorch 2.6 compatibility)")
else:
    print("⚠️ torch.load call not found in build_multitask.py - may already be fixed")

# -----------------------------------------------------------------------------
# FIX 3: Add LR scheduler stepping to train.py so it advances to correct epoch
# -----------------------------------------------------------------------------
with open('/content/yowo/train.py', 'r') as f:
    train_content = f.read()

# Find and fix the LR scheduler initialization
old_scheduler_code = '''    # lr scheduler
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_epoch, args.lr_decay_ratio)'''

new_scheduler_code = '''    # lr scheduler
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_epoch, args.lr_decay_ratio)
    
    # Advance scheduler to match resumed epoch (CRITICAL for correct LR decay!)
    if start_epoch > 0:
        print(f'Advancing LR scheduler by {start_epoch} steps to match resumed epoch...')
        for _ in range(start_epoch):
            lr_scheduler.step()
        print(f'LR after advancing: {lr_scheduler.get_last_lr()[0]:.6f}')'''

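# new_scheduler_code starts with old_scheduler_code, so check for the patch first;
# otherwise re-running this cell would advance the scheduler twice.
if 'Advancing LR scheduler' in train_content:
    print("✅ train.py already patched (LR scheduler advance) - skipping")
elif old_scheduler_code in train_content:
    train_content = train_content.replace(old_scheduler_code, new_scheduler_code)
    with open('/content/yowo/train.py', 'w') as f:
        f.write(train_content)
    print("✅ Fixed train.py (LR scheduler advances to correct epoch)")
else:
    print("⚠️ Could not find scheduler code - train.py may use a different format")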

print("✅ All fixes applied!")
print("")

# -----------------------------------------------------------------------------
# TRAINING CONFIGURATION
# -----------------------------------------------------------------------------
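CHECKPOINT = "/content/yowo/weights/charades_ag/yowo_v2_x3d_m_yolo11m_multitask/yowo_v2_x3d_m_yolo11m_multitask_epoch_5.pth"
assert os.path.exists(CHECKPOINT), f"Checkpoint not found: {CHECKPOINT}"
# Same settings as the fresh-start run (Cell 8); they override Cell 1's values
BATCH_SIZE = 160
ACCUMULATE = 2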
MAX_EPOCHS = 13
LEARNING_RATE = 0.00035
LR_DECAY_EPOCHS = "7 9 11 12"
LEN_CLIP = 16
NUM_WORKERS = 4


cmd = f"""python train.py \
    -d charades_ag \
    -v yowo_v2_x3d_m_yolo11m_multitask \
    --cuda \
    --amp \
    -bs {BATCH_SIZE} \
    -accu {ACCUMULATE} \
    --max_epoch {MAX_EPOCHS} \
    --lr_epoch {LR_DECAY_EPOCHS} \
    --root /content/yowo/data \
    -K {LEN_CLIP} \
    -lr {LEARNING_RATE} \
    --num_workers {NUM_WORKERS} \
    --save_folder /content/yowo/weights \
    -r {CHECKPOINT}"""

print("=" * 70)
print("🚀 RESUMING TRAINING FROM THE EPOCH-5 CHECKPOINT")
print("=" * 70)
print(f"📦 Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE * ACCUMULATE})")
print(f"📊 Max epochs: {MAX_EPOCHS}")
print(f"📈 Base LR: {LEARNING_RATE}")
print(f"📉 LR decay at epochs: {LR_DECAY_EPOCHS}")
print()
print("   ACTUAL LR Schedule (MultiStepLR behavior):")
print("   Epoch 6/13:  lr = 0.00035   (resuming here)")
print("   Epoch 7/13:  lr = 0.00035   (still full LR)")
print("   Epoch 8/13:  lr = 0.000175  (first decay after milestone 7)")
print("   Epoch 9/13:  lr = 0.000175")
print("   Epoch 10/13: lr = 0.0000875 (second decay after milestone 9)")
print("   Epoch 11/13: lr = 0.0000875")
print("   Epoch 12/13: lr = 0.00004375 (third decay after milestone 11)")
print("   Epoch 13/13: lr = 0.00002188 (fourth decay after milestone 12)")
print()
print(f"🎬 Clip length: {LEN_CLIP} frames")
print("⚡ AMP: Enabled")
print(f"\n📋 Full command:\n{cmd}\n")
print("=" * 70)

!{cmd}

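The resume cell edits `train.py` in place. Because `new_scheduler_code` begins with `old_scheduler_code`, a plain `str.replace` guarded only by `old in text` matches again on every re-run and inserts the advance loop twice. A generic idempotent patch helper avoids that by checking for a marker first; `patch_once` is a hypothetical name, sketched here in plain Python.

```python
def patch_once(text, old, new, marker):
    """Replace `old` with `new` once, unless `marker` shows the patch is present.

    Returns (patched_text, status) with status in
    {"applied", "already-patched", "not-found", "ambiguous"}.
    """
    if marker in text:
        return text, "already-patched"
    count = text.count(old)
    if count == 0:
        return text, "not-found"
    if count > 1:
        return text, "ambiguous"  # refuse to guess which occurrence to patch
    return text.replace(old, new, 1), "applied"
```

Usage: `train_content, status = patch_once(train_content, old_scheduler_code, new_scheduler_code, 'Advancing LR scheduler')`, then write the file back only when `status == "applied"`.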
In [None]:
# Cell 9: Save Weights to Google Drive (after training)
import shutil, os
from google.colab import drive

# Mount Drive first (the frames cell unmounts it); without a mount this path is
# just a local folder that disappears with the runtime.
drive.mount('/content/drive')

DRIVE_SAVE_PATH = "/content/drive/MyDrive/yooowo/weights"
os.makedirs(DRIVE_SAVE_PATH, exist_ok=True)

weights_dir = "/content/yowo/weights/charades_ag/yowo_v2_x3d_m_yolo11m_multitask"
if os.path.exists(weights_dir):
    for w in os.listdir(weights_dir):
        if w.endswith('.pth'):
            shutil.copy2(os.path.join(weights_dir, w), os.path.join(DRIVE_SAVE_PATH, w))
            print(f"✅ Saved {w} to Drive")
else:
    print("⚠️ No weights found yet")

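Cell 9 re-copies every checkpoint each time it runs. A sketch that skips files already on Drive with the same size (a cheap heuristic, not a checksum); `sync_weights` is a hypothetical helper.

```python
import os
import shutil


def sync_weights(src_dir, dst_dir, suffix=".pth"):
    """Copy files ending in `suffix` from src_dir to dst_dir.

    Files already present in dst_dir with the same size are skipped.
    Returns the names that were copied.
    """
    os.makedirs(dst_dir, exist_ok=True)
    copied = []
    for name in sorted(os.listdir(src_dir)):
        if not name.endswith(suffix):
            continue
        src = os.path.join(src_dir, name)
        dst = os.path.join(dst_dir, name)
        if os.path.exists(dst) and os.path.getsize(dst) == os.path.getsize(src):
            continue  # same size: assume already synced
        shutil.copy2(src, dst)
        copied.append(name)
    return copied
```

Usage: `sync_weights(weights_dir, DRIVE_SAVE_PATH)` returns the checkpoint names it copied on this run.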

## 🧪 Optional: Quick 1-Epoch Test

Run this first to verify everything works before full training:


In [None]:
# Quick test: one epoch with a small batch (AMP enabled) to verify the pipeline.
# `| head -80` cuts the run off after roughly the first 80 lines of output.
# Uncomment the line below to run:

# !python train.py -d charades_ag -v yowo_v2_x3d_m_yolo11m_multitask --cuda --amp -bs 4 --max_epoch 1 --root /content/yowo/data -K 16 --num_workers 2 2>&1 | head -80

# If it works, you should see the losses (logged every 10 iterations) trending down.
# Then run Cell 8 for full training.


## 📈 Resume Training from Checkpoint


In [None]:
# Resume from checkpoint (uncomment and modify path).
# Apply the PyTorch 2.6 fixes from the resume cell above first, and pass the same
# flags as the original run (e.g. --amp, -lr, --lr_epoch) so the schedule matches.
# CHECKPOINT = "/content/yowo/weights/charades_ag/yowo_v2_x3d_m_yolo11m_multitask/yowo_v2_x3d_m_yolo11m_multitask_epoch_5.pth"
# !python train.py -d charades_ag -v yowo_v2_x3d_m_yolo11m_multitask --cuda -bs {BATCH_SIZE} -accu {ACCUMULATE} --max_epoch 20 --root /content/yowo/data -K 16 -r {CHECKPOINT} --eval

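Rather than editing `CHECKPOINT` by hand, the newest `*_epoch_N.pth` file can be picked automatically. The sketch compares `N` numerically, so `epoch_10` ranks after `epoch_9` (a plain string sort would not); `latest_checkpoint` is a hypothetical helper.

```python
import os
import re

# Matches the "..._epoch_N.pth" naming used for the saved checkpoints
CKPT_RE = re.compile(r"_epoch_(\d+)\.pth$")


def latest_checkpoint(weights_dir):
    """Return the path of the *_epoch_N.pth file with the largest N, or None."""
    if not os.path.isdir(weights_dir):
        return None
    best_epoch, best_path = -1, None
    for name in os.listdir(weights_dir):
        m = CKPT_RE.search(name)
        if m and int(m.group(1)) > best_epoch:
            best_epoch, best_path = int(m.group(1)), os.path.join(weights_dir, name)
    return best_path
```

Usage: `CHECKPOINT = latest_checkpoint("/content/yowo/weights/charades_ag/yowo_v2_x3d_m_yolo11m_multitask")`.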

## 🔧 Troubleshooting

| Problem | Solution |
|---------|----------|
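| **OOM Error** | Halve `BATCH_SIZE` and double `ACCUMULATE` so the effective batch size (`BATCH_SIZE × ACCUMULATE`) stays the same |
| **Training slow** | Increase batch size if GPU memory allows (L4/A100 can go higher). If GPU utilization is low, data loading may be the bottleneck; try raising `NUM_WORKERS`. |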
| **Loss not decreasing** | Try lr=0.0005 (higher) or lr=0.00005 (lower) |
| **`loss is NAN !!`** | Reduce learning rate to 0.00005, or check for bad data samples |
| **Loss stuck high** | Verify dataset extracted correctly, check annotations |
| **loss_act = 0.00** | This is NORMAL - some frames have no person, so no action loss |
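
The OOM rule keeps `BATCH_SIZE × ACCUMULATE` fixed, so each optimizer step still sees the same number of samples; with BatchNorm layers the result is only approximately equivalent, since normalization statistics depend on the per-step batch. A minimal sketch of the arithmetic (`rebalance_for_oom` is a hypothetical helper); for odd batch sizes it rounds `ACCUMULATE` up so the effective batch never shrinks.

```python
import math


def rebalance_for_oom(batch_size, accumulate):
    """Halve the per-step batch and scale accumulation so the effective batch
    (batch_size * accumulate) stays the same, rounding up for odd sizes."""
    if batch_size < 2:
        raise ValueError("batch_size is already 1 and cannot be halved")
    effective = batch_size * accumulate
    new_batch = batch_size // 2
    new_accumulate = math.ceil(effective / new_batch)
    return new_batch, new_accumulate
```

Usage: `rebalance_for_oom(8, 4)` gives `(4, 8)` and `rebalance_for_oom(5, 4)` gives `(2, 10)`.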

## 📁 Output Files

After training:
- **Weights**: `/content/yowo/weights/charades_ag/yowo_v2_x3d_m_yolo11m_multitask/`
- **Checkpoints**: `yowo_v2_x3d_m_yolo11m_multitask_epoch_N.pth`

**⚠️ IMPORTANT:** Run Cell 9 to copy weights to Google Drive before the runtime disconnects!
