# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataset import create_wall_dataloader
import wandb

class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size equals its stride (``patch_size``) acts as a
    per-patch linear projection. The resulting grid of patch embeddings is
    flattened into a token sequence and layer-normalised.
    """

    def __init__(self, img_size=64, patch_size=8, in_chans=2, embed_dim=256):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Floor division: trailing pixels that do not fill a whole patch are ignored.
        per_side = img_size // patch_size
        self.n_patches = per_side * per_side

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Turn images into a normalised sequence of patch tokens."""
        grid = self.proj(x)                          # B, E, H/P, W/P
        seq = torch.flatten(grid, start_dim=2)       # B, E, N
        seq = seq.permute(0, 2, 1)                   # B, N, E
        return self.norm(seq)

class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Applies ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``.

    Args:
        dim: token embedding width.
        num_heads: number of attention heads (nn.MultiheadAttention requires
            ``dim`` to be divisible by it).
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        drop: dropout probability used in attention and the MLP.
    """

    def __init__(self, dim, num_heads=8, mlp_ratio=4., drop=0.):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        # Submodules are created in the original order so parameter init is unchanged.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, N, dim) (batch_first=True)."""
        h = self.norm1(x)
        # The attention map is never used. need_weights=False skips computing and
        # averaging it, and lets PyTorch dispatch to scaled_dot_product_attention.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

class ViTEncoder(nn.Module):
    """Vision Transformer that encodes an image into a single CLS embedding.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_chans: number of input channels.
        embed_dim: token / output embedding width.
        depth: number of Transformer blocks (default 6).
        num_heads: attention heads per block (nn.MultiheadAttention requires
            ``embed_dim`` to be divisible by it).
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        drop: dropout probability inside each block.

    The new keyword arguments default to the previously hard-coded values, so
    the default architecture and its state_dict keys are unchanged.
    """

    def __init__(self, img_size=64, patch_size=8, in_chans=2, embed_dim=256,
                 depth=6, num_heads=8, mlp_ratio=4., drop=0.):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        # Reuse PatchEmbed's patch count instead of re-deriving it here.
        num_patches = self.patch_embed.n_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One learned position per patch, plus one for the CLS token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Initialize weights (same order as before so RNG consumption matches).
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for nn.Linear weights and zero biases.

        Modules that are not nn.Linear (e.g. the patch Conv2d, LayerNorms) keep
        PyTorch's default initialisation.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Encode images into their final-layer CLS embeddings, shape (B, embed_dim)."""
        x = self.patch_embed(x)  # B, N, E

        # Prepend the CLS token, then add positional embeddings.
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embed

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        return x[:, 0]  # Return CLS token only

class EncoderTrainer(nn.Module):
    """BYOL-style self-supervised wrapper around a ViT encoder.

    An online ``encoder`` followed by an MLP ``predictor`` is trained to match
    the (gradient-free, EMA-updated) ``target_encoder`` embedding of the other
    frame of a pair. The match is enforced symmetrically in both directions.
    """

    def __init__(self, device="cuda", embed_dim=256):
        # NOTE(review): `device` is accepted but unused; callers move the module themselves.
        super().__init__()
        self.encoder = ViTEncoder(embed_dim=embed_dim)
        self.target_encoder = ViTEncoder(embed_dim=embed_dim)

        # The target network never receives gradients, only EMA updates.
        for param in self.target_encoder.parameters():
            param.requires_grad = False

        # m=0.0 copies the online weights, so both encoders start identical.
        self.momentum_update(m=0.0)

        self.predictor = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )

    @torch.no_grad()
    def momentum_update(self, m=0.996):
        """EMA update: ``target = m * target + (1 - m) * online``.

        Args:
            m: momentum, expected in [0, 1]. ``m == 0`` copies the online
                weights exactly.
        """
        for param_q, param_k in zip(self.encoder.parameters(),
                                    self.target_encoder.parameters()):
            if m == 0.0:
                param_k.copy_(param_q)
            else:
                # In place: avoids allocating a new tensor per parameter per step.
                param_k.mul_(m).add_(param_q, alpha=1.0 - m)

    def forward(self, x):
        """Return the online prediction ``predictor(encoder(x))``."""
        z = self.encoder(x)
        p = self.predictor(z)
        return p

    def compute_loss(self, states, debug=False):
        """Symmetric BYOL loss between frames 0 and 1 of ``states``.

        Args:
            states: batch indexed per frame as ``states[:, t]``; only t=0 and t=1
                are used (presumably (B, T, C, H, W) — TODO confirm with the loader).
            debug: if True, print the loss and embedding diagnostics.

        Returns:
            Scalar ``(2 - 2*cos(p1, t2)) + (2 - 2*cos(p2, t1))``, which lies in
            [0, 8] and is 0 when both predictions match their targets exactly.
        """
        frame_a = states[:, 0]  # First frame
        frame_b = states[:, 1]  # Second frame

        # Online branch (receives gradients).
        p1 = self.predictor(self.encoder(frame_a))
        p2 = self.predictor(self.encoder(frame_b))

        # Target branch (no gradients).
        with torch.no_grad():
            t1 = self.target_encoder(frame_a)
            t2 = self.target_encoder(frame_b)

        if debug:
            # Capture raw norms before F.normalize; afterwards they are always ~1.
            with torch.no_grad():
                pred_norm = p1.norm(dim=1).mean()
                target_norm = t1.norm(dim=1).mean()

        p1 = F.normalize(p1, dim=-1)
        p2 = F.normalize(p2, dim=-1)
        t1 = F.normalize(t1, dim=-1)
        t2 = F.normalize(t2, dim=-1)

        sim_12 = (p1 * t2).sum(dim=-1).mean()
        sim_21 = (p2 * t1).sum(dim=-1).mean()
        # Each direction is the standard BYOL term 2 - 2*cos (>= 0), so the
        # total is non-negative. Gradients match the unsummed variant.
        loss = (2 - 2 * sim_12) + (2 - 2 * sim_21)

        if debug:
            with torch.no_grad():
                print(f"\nLoss: {loss.item():.4f}")
                print(f"Prediction norm: {pred_norm:.4f}")
                print(f"Target norm: {target_norm:.4f}")
                print(f"Cosine sim: {F.cosine_similarity(p1, t2).mean():.4f}")

        return loss

def train_encoder():
    """Train the ViT encoder with BYOL-style self-supervision and save it.

    Hyper-parameters are recorded in, and read back from, ``wandb.config``.
    Per-batch and per-epoch losses are logged to Weights & Biases. The online
    encoder's ``state_dict`` is written to ``encoder.pth`` and uploaded to the
    run. The wandb run is always finished, even if training raises.

    Returns:
        The trained online encoder (``model.encoder``).

    Raises:
        RuntimeError: if the training dataloader yields no batches.
    """
    # Initialize wandb
    wandb.init(project="vit-encoder-training", config={
        "learning_rate": 1e-4,
        "weight_decay": 0.05,
        "batch_size": 32,
        "epochs": 100,
        "scheduler": "CosineAnnealingLR"
    })

    try:
        # Create dataloader using the provided function
        train_loader = create_wall_dataloader(
            data_path="/scratch/DL24FA/train",
            probing=False,
            device="cuda",
            batch_size=wandb.config.batch_size,
            train=True
        )

        model = EncoderTrainer().cuda()
        # Optimise only trainable params; the frozen target encoder is updated by EMA.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(
            trainable_params,
            lr=wandb.config.learning_rate,
            weight_decay=wandb.config.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=wandb.config.epochs
        )

        # Log the model architecture
        wandb.watch(model.encoder, log="all", log_freq=10)

        for epoch in range(wandb.config.epochs):
            total_loss = 0.0
            num_batches = 0
            # Diagnostics every 10 epochs, printed once (first batch) rather than per batch.
            debug_epoch = epoch % 10 == 0

            for batch in train_loader:
                optimizer.zero_grad()

                loss = model.compute_loss(
                    batch.states, debug=debug_epoch and num_batches == 0
                )
                loss.backward()

                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)

                optimizer.step()
                model.momentum_update()

                loss_value = loss.item()  # one host sync per step
                total_loss += loss_value
                num_batches += 1

                wandb.log({"batch_loss": loss_value})

            if num_batches == 0:
                raise RuntimeError("train_loader yielded no batches")

            avg_loss = total_loss / num_batches
            print(f"Epoch {epoch}, Average Loss: {avg_loss:.4f}")
            wandb.log({
                "epoch": epoch,
                "avg_loss": avg_loss,
                "learning_rate": scheduler.get_last_lr()[0]
            })
            scheduler.step()

        # Save the trained encoder
        torch.save(model.encoder.state_dict(), 'encoder.pth')
        wandb.save('encoder.pth')
    finally:
        # Close the wandb run on both success and failure.
        wandb.finish()

    return model.encoder

if __name__ == "__main__":
    # Script entry point: run the full training job and keep the trained
    # encoder bound at module level so it can be inspected afterwards.
    encoder = train_encoder()