# 🎯 CNN-Based Telugu Poem Interpretation - Training

**Train the CNN+RNN model for Telugu poem analysis and interpretation**

## Step 1: Setup & Clone Project

In [None]:
# Clone the project repository from GitHub into /content/project, then make
# that directory the working directory. Later cells rely on this: they open
# data files by repo-relative path (e.g. data/processed/telugu_poems.json).
!git clone https://github.com/maneendra03/CNN-Based-Telugu-Poem-Analysis-inspired-by-human-rote-learning.git /content/project
%cd /content/project

In [None]:
# Install dependencies
!pip install torch torchvision torchaudio
!pip install tqdm pyyaml

In [None]:
# Report the PyTorch build and whether a CUDA device is visible to it
import torch

cuda_available = torch.cuda.is_available()

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {cuda_available}")
if cuda_available:
    # Index 0 is the first CUDA device
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Step 2: Load Telugu Dataset

In [None]:
import json
import sys
sys.path.insert(0, '/content/project')

from src.data.data_loader import PoemDataLoader
from src.preprocessing.tokenizer import PoemTokenizer

# Read the processed poem corpus; the path is relative to the current
# working directory.
with open('data/processed/telugu_poems.json', encoding='utf-8') as poem_file:
    poems = json.loads(poem_file.read())

# Report how many poems were loaded, then preview the first 100 characters
# of the first poem's text.
print(f"‚úÖ Loaded {len(poems)} Telugu poems")
first_poem_text = poems[0]['text']
print(f"Sample: {first_poem_text[:100]}...")

## Step 3: Create DataLoader

In [None]:
from torch.utils.data import Dataset, DataLoader
from src.preprocessing.telugu_cleaner import TeluguTextCleaner

class TeluguPoemDataset(Dataset):
    """Fixed-length token sequences for next-token prediction over poems.

    Every item is built from one poem: the text is cleaned, encoded with the
    tokenizer, forced to exactly ``max_length`` ids, and then split into an
    input (all ids but the last) and a target (all ids but the first).
    """

    def __init__(self, poems, tokenizer, max_length=100):
        """Keep the poems, tokenizer and sequence length; create a text cleaner.

        Args:
            poems: Indexable collection whose entries are either dicts with a
                'text' key or plain strings.
            tokenizer: Object exposing ``encode(text)`` returning a list of ids.
            max_length: Number of ids each poem is truncated or padded to.
        """
        self.poems = poems
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.cleaner = TeluguTextCleaner()

    def __len__(self):
        """Return the number of poems."""
        return len(self.poems)

    def __getitem__(self, idx):
        """Return ``{'input_ids', 'target_ids'}`` long tensors for poem ``idx``."""
        entry = self.poems[idx]
        # Dict entries carry their text under 'text' (empty string if absent);
        # anything else is used as the text itself.
        raw_text = entry.get('text', '') if isinstance(entry, dict) else entry

        token_ids = self.tokenizer.encode(self.cleaner.clean(raw_text))

        # Force exactly max_length ids: cut sequences that are too long,
        # right-pad with 0 the ones that are not.
        missing = self.max_length - len(token_ids)
        if missing < 0:
            token_ids = token_ids[:self.max_length]
        else:
            token_ids = token_ids + [0] * missing

        # Shift by one position so target_ids[i] is the id following input_ids[i]
        return {
            'input_ids': torch.tensor(token_ids[:-1], dtype=torch.long),
            'target_ids': torch.tensor(token_ids[1:], dtype=torch.long),
        }

from sklearn.model_selection import train_test_split

# Create tokenizer
# Pull the text out of each entry the same way TeluguPoemDataset.__getitem__
# does (dict -> its 'text' value, anything else -> the entry itself), so the
# tokenizer and the dataset accept the same corpus formats.
poem_texts = [p.get('text', '') if isinstance(p, dict) else p for p in poems]

tokenizer = PoemTokenizer(min_freq=1)
tokenizer.fit(poem_texts)

print(f"✅ Tokenizer: vocab_size={tokenizer.word_vocab_size}")

# Create datasets
# Two-stage split with a fixed seed: 20% of all poems are held out for test,
# then 10% of the remainder becomes the validation set.
train_poems, test_poems = train_test_split(poems, test_size=0.2, random_state=42)
train_poems, val_poems = train_test_split(train_poems, test_size=0.1, random_state=42)

train_dataset = TeluguPoemDataset(train_poems, tokenizer, max_length=100)
val_dataset = TeluguPoemDataset(val_poems, tokenizer, max_length=100)
test_dataset = TeluguPoemDataset(test_poems, tokenizer, max_length=100)

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

print(f"✅ Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

## Step 4: Initialize CNN Interpretation Model

In [None]:
from src.models.poem_learner import PoemLearner

# Pick the GPU when PyTorch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Create CNN-based interpretation model, sized to the tokenizer's vocabulary,
# and move it onto the chosen device
model = PoemLearner(
    vocab_size=tokenizer.word_vocab_size,
    embedding_dim=256,
    hidden_dim=512,
).to(device)

# Count parameters: element count of every parameter tensor, together with
# whether that tensor takes part in gradient updates
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total = sum(size for size, _ in param_sizes)
trainable = sum(size for size, needs_grad in param_sizes if needs_grad)

print(f"‚úÖ CNN Interpretation Model")
print(f"   Total params: {total:,}")
print(f"   Trainable: {trainable:,}")

## Step 5: Train CNN Model

In [None]:
from pathlib import Path

from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from tqdm import tqdm

# Training config
CONFIG = {
    'epochs': 50,
    'learning_rate': 1e-3,
    'save_every': 5
}

optimizer = Adam(model.parameters(), lr=CONFIG['learning_rate'])

# Fallback criterion, used only when the model's output carries no loss.
# Built once here instead of on every batch. ignore_index=0 leaves out target
# positions equal to 0, the value TeluguPoemDataset pads with.
loss_fn = CrossEntropyLoss(ignore_index=0)

# Create checkpoint dir
checkpoint_dir = Path('/content/project/checkpoints/interpretation')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

print("🚀 Training CNN Interpretation Model...")
print(f"   Epochs: {CONFIG['epochs']}")
print(f"   Batches: {len(train_loader)}")

# Lowest average training loss seen so far.
# NOTE(review): "best" is judged on the training loss; val_loader is not used
# in this cell.
best_loss = float('inf')

for epoch in range(CONFIG['epochs']):
    model.train()
    epoch_loss = 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")
    for batch in progress:
        input_ids = batch['input_ids'].to(device)
        target_ids = batch['target_ids'].to(device)

        # Forward pass
        outputs = model(input_ids, target_ids)

        # Prefer the loss the model reports; otherwise compute token-level
        # cross-entropy from the logits, flattened to (tokens, vocab).
        if 'loss' in outputs and outputs['loss'] is not None:
            loss = outputs['loss']
        else:
            logits = outputs['logits']
            loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()
        progress.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_loss = epoch_loss / len(train_loader)
    print(f"\n📊 Epoch {epoch+1} | Loss: {avg_loss:.4f}")

    # Save checkpoint(s). The "best" and the periodic checkpoint are decided
    # independently, so an epoch that is both a new best and a multiple of
    # save_every writes both files instead of dropping the periodic one.
    save_paths = []
    if avg_loss < best_loss:
        best_loss = avg_loss
        save_paths.append(checkpoint_dir / 'best_cnn_interpretation.pt')
    if (epoch + 1) % CONFIG['save_every'] == 0:
        save_paths.append(checkpoint_dir / f'cnn_epoch_{epoch+1}.pt')

    if save_paths:
        # 'epoch' is stored zero-based; the file name uses the one-based number.
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'loss': avg_loss,
            'vocab_size': tokenizer.word_vocab_size
        }
        for save_path in save_paths:
            torch.save(checkpoint, save_path)
            print(f"💾 Saved: {save_path}")

print("\n✅ Training Complete!")

## Step 6: Test Interpretation

In [None]:
# Test poem interpretation
model.eval()

# Hand-written sample lines (Telugu).
# NOTE(review): this rebinds `test_poems`, which the dataset cell also uses
# for the held-out split; the name is kept so existing references still work.
test_poems = [
    "చందమామ రావే చాల బాగుందే",
    "తెలుగు భాష మధురమైనది",
    "అమ్మ ప్రేమ అమృతమయం"
]

# Reuse the training dataset class so the samples go through the same
# cleaning, padding/truncation and input/target shift as the training data
# (here with sequences of 50 ids).
sample_dataset = TeluguPoemDataset(test_poems, tokenizer, max_length=50)

print("📝 Telugu Poem Interpretation Test")
print("=" * 50)

for idx, poem_text in enumerate(test_poems):
    print(f"\n🔹 Poem: {poem_text}")

    # Encode and add a leading batch dimension of size 1
    sample = sample_dataset[idx]
    input_ids = sample['input_ids'].unsqueeze(0).to(device)
    target_ids = sample['target_ids'].unsqueeze(0).to(device)

    # Inference only: no gradients are needed
    with torch.no_grad():
        output = model(input_ids, target_ids)

    print(f"   ✅ Interpreted successfully")
    print(f"   Features: {output['poem_representation'].shape}")
    print(f"   Logits: {output['logits'].shape}")

## Step 7: Save Final Model

In [None]:
# Save final model: the weights plus the metadata needed to rebuild it
# (training config, vocabulary size) and the best training loss reached
final_path = checkpoint_dir.joinpath('final_cnn_interpretation.pt')

final_checkpoint = dict(
    model_state_dict=model.state_dict(),
    config=CONFIG,
    vocab_size=tokenizer.word_vocab_size,
    best_loss=best_loss,
)
torch.save(final_checkpoint, final_path)

print(f"‚úÖ Final model saved: {final_path}")
print(f"\nüìÅ Download from: /content/project/checkpoints/interpretation/")