In [None]:
from train import fit
from model import MovieHeteroGAT
import torch
import os
import gc
from dotenv import load_dotenv
from tqdm.notebook import tqdm

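# Make the parent directory importable.
# NOTE(review): this runs after the `train` / `model` imports above, so it
# cannot help resolve them - move it above those imports if they live in `..`.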
import sys
sys.path.insert(0, '..')

# Load environment variables from .env file
load_dotenv()

# Check CUDA availability, clean up, and pick the training device.
# Result: `device` is torch.device('cuda') when a CUDA smoke test succeeds,
# otherwise torch.device('cpu').
print("Checking GPU status...")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    try:
        # Collect Python garbage first so that unreachable CUDA tensors are
        # freed *before* the caching allocator is asked to release its unused
        # blocks. In the opposite order, memory freed by the collector would
        # stay reserved and the numbers printed below would be stale.
        gc.collect()
        torch.cuda.empty_cache()
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
        print(f"Memory reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

        # Test CUDA by creating a small tensor; synchronize so that any
        # asynchronous CUDA error surfaces here rather than during training.
        test_tensor = torch.tensor([1.0]).cuda()
        del test_tensor
        torch.cuda.synchronize()
        print("✓ CUDA is working")
        device = torch.device('cuda')
    except Exception as e:
        # Deliberate best-effort: any failure while probing the GPU degrades
        # to CPU training instead of aborting the notebook cell.
        print(f"⚠ CUDA error detected: {e}")
        print("Falling back to CPU...")
        device = torch.device('cpu')
else:
    print("Using CPU")
    device = torch.device('cpu')

print("\nLoading data...")
# 1. Load Data
# The dataset location comes from the DATASET_PATH environment variable.
data_path = os.getenv("DATASET_PATH")
if not data_path:
    # Covers both an unset variable and an empty `DATASET_PATH=` entry, which
    # would otherwise fail later inside torch.load with a confusing message.
    raise ValueError("DATASET_PATH environment variable not set. Please check your .env file.")
if not os.path.isfile(data_path):
    raise FileNotFoundError(f"DATASET_PATH does not point to an existing file: {data_path}")

# SECURITY: weights_only=False unpickles arbitrary Python objects and can
# execute code embedded in the file. Only load datasets from a trusted source.
# Presumably required because the file stores a full graph object rather than
# plain tensors - TODO confirm.
data = torch.load(data_path, weights_only=False)
print(f"✓ Data loaded: {data['user'].num_nodes:,} users, {data['movie'].num_nodes:,} movies")

# 2. Init Model
print("\nInitializing model...")

# Architecture hyperparameters, kept together so they are easy to tweak.
# NOTE(review): the recorded cell output reports
# "HeteroGATv2_L2_H2_Emb64_Hidden64_Out64_...", which does not match these
# values - either the output is from an earlier run or MovieHeteroGAT does not
# honour these arguments. Confirm before trusting the run.
model_hparams = {
    "hidden_channels": 256,
    "out_channels": 512,
    "num_layers": 3,
    "num_heads": 4,
}

# Graph-derived arguments (schema and node counts) come straight from `data`.
model = MovieHeteroGAT(
    metadata=data.metadata(),
    num_users=data['user'].num_nodes,
    num_movies=data['movie'].num_nodes,
    **model_hparams,
)

print(f"✓ Model initialized: {model.model_details}")
print(f"✓ Using device: {device}")

# 3. Train
print("\nStarting training...")
print("=" * 60)

# Use a smaller batch when training falls back to the CPU.
if device.type == 'cuda':
    train_batch_size = 4096
else:
    train_batch_size = 512

fit(
    model,
    data,
    epochs=20,
    batch_size=train_batch_size,
    experiment_name="MyFirstRun",
    device=device,  # pass the selected device explicitly
)

print("=" * 60)
print("✓ Training complete!")

  from .autonotebook import tqdm as notebook_tqdm


Checking GPU status...
CUDA available: True
GPU: NVIDIA GeForce RTX 3090
Memory allocated: 0.00 GB
Memory reserved: 0.00 GB
✓ CUDA is working

Loading data...
✓ Data loaded: 11,061 users, 142,374 movies

Initializing model...
✓ Model initialized: HeteroGATv2_L2_H2_Emb64_Hidden64_Out64_kaiming_DecoderMLP
✓ Using device: cuda

Starting training...
Initializing Graph Loaders... (This maps the topology)
Rating range in data: [0.50, 5.00]
Split sizes - Train: 14,234,477, Val: 1,779,310, Test: 1,779,310
Training on 3476 batches. Validation on 435 batches. Test on 435 batches.

Epoch 1/20


Train Epoch 1: 100%|██████████| 3476/3476 [01:16<00:00, 45.70it/s, loss=1.1019, rmse=1.0497]
                                                             


End of Epoch 1
Time: 00:01:21
Train RMSE: 1.0915 | Val RMSE: 1.0433
⭐ New Best Model Saved! (RMSE: 1.0433)

Epoch 2/20


Train Epoch 2: 100%|██████████| 3476/3476 [01:16<00:00, 45.28it/s, loss=0.9998, rmse=0.9999]
                                                             


End of Epoch 2
Time: 00:01:22
Train RMSE: 1.0428 | Val RMSE: 1.0430
⭐ New Best Model Saved! (RMSE: 1.0430)

Epoch 3/20


Train Epoch 3: 100%|██████████| 3476/3476 [01:16<00:00, 45.59it/s, loss=1.1386, rmse=1.0670]
                                                             


End of Epoch 3
Time: 00:01:21
Train RMSE: 1.0428 | Val RMSE: 1.0435

Epoch 4/20


Train Epoch 4: 100%|██████████| 3476/3476 [01:16<00:00, 45.55it/s, loss=1.1898, rmse=1.0908]
                                                             


End of Epoch 4
Time: 00:01:22
Train RMSE: 1.0427 | Val RMSE: 1.0430

Epoch 5/20


Train Epoch 5: 100%|██████████| 3476/3476 [01:16<00:00, 45.53it/s, loss=0.9927, rmse=0.9964]
                                                             


End of Epoch 5
Time: 00:01:22
Train RMSE: 1.0427 | Val RMSE: 1.0430

Epoch 6/20


Train Epoch 6: 100%|██████████| 3476/3476 [01:16<00:00, 45.54it/s, loss=1.0536, rmse=1.0265]
                                                             


End of Epoch 6
Time: 00:01:21
Train RMSE: 1.0427 | Val RMSE: 1.0430

Epoch 7/20


Train Epoch 7: 100%|██████████| 3476/3476 [01:16<00:00, 45.57it/s, loss=1.0546, rmse=1.0270]
                                                             


End of Epoch 7
Time: 00:01:21
Train RMSE: 1.0427 | Val RMSE: 1.0430
⭐ New Best Model Saved! (RMSE: 1.0430)

Epoch 8/20


Train Epoch 8: 100%|██████████| 3476/3476 [01:15<00:00, 45.82it/s, loss=0.9601, rmse=0.9798]
                                                             


End of Epoch 8
Time: 00:01:21
Train RMSE: 1.0427 | Val RMSE: 1.0431

Epoch 9/20


Train Epoch 9: 100%|██████████| 3476/3476 [01:16<00:00, 45.26it/s, loss=1.0777, rmse=1.0381]
                                                             


End of Epoch 9
Time: 00:01:22
Train RMSE: 1.0427 | Val RMSE: 1.0430

Epoch 10/20


Train Epoch 10:  96%|█████████▌| 3339/3476 [01:12<00:02, 45.76it/s, loss=1.0626, rmse=1.0308]


KeyboardInterrupt: 