In [1]:
import os
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from dataloader_rec import ClimateReconstructionDataset
from generative_world_model import Generative_World_Model

# ---- Configuration ---------------------------------------------------------
# Run name; it is interpolated into the checkpoint path below and into the
# result file names written at the end of the script.
backbone = 'beamvq_reconstruction_v1'

# Input data root and output locations.
DATA_DIR = "/jizhicfs/easyluwu/scaling_law/ft_local/low_res"
CHECKPOINT_PATH = f'./checkpoints/{backbone}_best_model.pth'
RESULT_DIR = './results'

BATCH_SIZE = 3
# Variable indices 0..68 passed to the dataset; presumably one per model input
# channel (the model is built with in_channel=69) — confirm against the dataset.
VARIABLES = range(69)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- Model -----------------------------------------------------------------
# NOTE(review): `BeamVQ` is not imported anywhere in this file (only
# `Generative_World_Model` is); confirm which module provides it.
model = BeamVQ(
    in_channel=69,
    res_layers=2,
    embedding_nums=1024,
    embedding_dim=256,
    top_k=10,
).to(device)

# Restore trained weights. Keys saved from a DDP-wrapped model carry a
# 'module.' prefix; strip it so they match the bare model's parameter names.
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint not found at {CHECKPOINT_PATH}")

state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
new_state_dict = {
    (name[len('module.'):] if name.startswith('module.') else name): weights
    for name, weights in state_dict.items()
}
model.load_state_dict(new_state_dict)
print(f"Loaded model from {CHECKPOINT_PATH}")

# Switch to evaluation mode before running inference.
model.eval()

# ---- Test data -------------------------------------------------------------
# Test split: years 2019, 2020 and 2021 (the range end is exclusive),
# restricted to the configured variable indices.
test_dataset = ClimateReconstructionDataset(
    data_path=DATA_DIR,
    years=range(2019, 2022),
    variables=VARIABLES,
)

# shuffle=False keeps batches in dataset order; page-locked (pinned) host
# memory is only requested when a CUDA device is available.
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

# Inference function
def run_inference(model, test_loader, device):
    """Run ``model`` over every batch of ``test_loader`` and gather the results.

    Gradient tracking is disabled for the whole pass. Each batch is moved to
    ``device`` and fed to ``model``. The inputs, targets and predictions are
    then copied back to host memory as NumPy arrays.

    Args:
        model: callable that takes a batch of inputs and returns a 3-tuple
            ``(pred, _, vq_loss)``. The middle element is ignored, and
            ``vq_loss`` must be a single-element tensor (``.item()`` is
            called on it).
        test_loader: iterable yielding ``(inputs, targets)`` tensor pairs.
        device: device the batches are moved to before calling ``model``.

    Returns:
        Tuple ``(inputs, targets, outputs, avg_vq_loss)``. The first three are
        the per-batch arrays concatenated along axis 0. ``avg_vq_loss`` is the
        VQ loss averaged over samples, with each batch weighted by its size.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    all_inputs = []
    all_targets = []
    all_outputs = []
    vq_loss_sum = 0.0
    num_samples = 0

    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc="Running inference"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            pred, _, vq_loss = model(inputs)

            # Collect results as host-side NumPy arrays.
            all_inputs.append(inputs.cpu().numpy())
            all_targets.append(targets.cpu().numpy())
            all_outputs.append(pred.cpu().numpy())

            # The last batch may be smaller than the others, so an unweighted
            # mean of per-batch losses would over-count its samples. Weight by
            # batch size instead. Assumes vq_loss is a mean over its batch —
            # TODO confirm against the model.
            batch_size = inputs.shape[0]
            vq_loss_sum += vq_loss.item() * batch_size
            num_samples += batch_size

    if not all_outputs:
        # np.concatenate would fail with an opaque message on an empty list.
        raise ValueError("test_loader yielded no batches; nothing to run inference on")

    # Concatenate all batches
    all_inputs = np.concatenate(all_inputs, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)
    all_outputs = np.concatenate(all_outputs, axis=0)
    avg_vq_loss = vq_loss_sum / num_samples

    return all_inputs, all_targets, all_outputs, avg_vq_loss

# ---- Run, save, report -----------------------------------------------------
print("Starting inference...")
inputs, targets, outputs, avg_vq_loss = run_inference(model, test_loader, device)

# Each array is written to <RESULT_DIR>/<backbone>_<tag>.npy.
os.makedirs(RESULT_DIR, exist_ok=True)
for tag, arr in (("inputs", inputs), ("targets", targets), ("outputs", outputs)):
    np.save(f'{RESULT_DIR}/{backbone}_{tag}.npy', arr)

print(f"Inference completed. Results saved to {RESULT_DIR}")
print(f"Average VQ Loss: {avg_vq_loss:.14f}")

# Element-wise mean squared error between targets and reconstructions.
mse = np.square(targets - outputs).mean()
print(f"Reconstruction MSE: {mse:.14f}")

Loaded model from ./checkpoints/beamvq_reconstruction_v1_best_model.pth
Starting inference...


Running inference: 100%|█████████████████████████████████████████████████████████████████████████| 366/366 [02:34<00:00,  2.36it/s]


Inference completed. Results saved to ./results
Average VQ Loss: 0.00000041711566
Reconstruction MSE: 0.00001372790211
