In [None]:
import torch
import torch.nn as nn
import numpy as np
import os
from torch.cuda.amp import autocast

class Generator(nn.Module):
    def __init__(self, noise_dim=100):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim + 3, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 3),
            nn.Sigmoid()
        )

    def forward(self, z, conditions):
        x = torch.cat([z, conditions], dim=1)
        return self.net(x)

def load_for_inference(checkpoint_path, device='cpu'):
    """Load model and normalization parameters from checkpoint"""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Initialize generator
    generator = Generator(noise_dim=100).to(device)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    generator.eval()
    
    # Get normalization parameters
    data_min = checkpoint['data_min']
    data_max = checkpoint['data_max']
    
    return generator, (data_min, data_max)

def normalize(value, data_min, data_max):
    """Normalize using min-max scaling (matches training)"""
    return (value - data_min) / (data_max - data_min)

def denormalize(value, data_min, data_max):
    """Denormalize using min-max scaling"""
    return value * (data_max - data_min) + data_min

def generate(generator, norm_params, conditions, device='cpu'):
    """Generate predictions with optional conditioning"""
    data_min, data_max = norm_params
    input_norm = np.zeros(3)
    mask = np.ones(3, dtype=bool)
    
    # Normalize input conditions
    for i, key in enumerate(['Enrichment (%)', 'Flux', 'Burnup']):
        if conditions[key] is not None:
            input_norm[i] = normalize(conditions[key], data_min[i], data_max[i])
            mask[i] = False
    
    with torch.no_grad(), autocast():
        z = torch.randn(1, 100, device=device)
        conditions_tensor = torch.FloatTensor(input_norm).unsqueeze(0).to(device)
        output_norm = generator(z, conditions_tensor).cpu().numpy()[0]
    
    # Denormalize results
    results = {}
    for i, key in enumerate(['Enrichment (%)', 'Flux', 'Burnup']):
        if mask[i]:
            val = denormalize(output_norm[i], data_min[i], data_max[i])
            status = "(predicted)"
        else:
            val = conditions[key]
            status = "(input)"
        results[key] = (val, status)
    
    return results

def format_small_values(value):
    """Format very small values appropriately"""
    if isinstance(value, (float, np.floating)) and abs(value) < 1e-4:
        return f"{value:.10e}"
    return f"{value:.4f}"

if __name__ == "__main__":
    # Configuration
    checkpoint_dir = '/home/jovyan/FluxGAN/plots/checkpoint'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    try:
        # Find latest checkpoint
        checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.tar')]
        latest = max(checkpoints, key=lambda x: int(x.split('_')[1].split('.')[0]))
        checkpoint_path = os.path.join(checkpoint_dir, latest)
        print(f"Loading {checkpoint_path}")
        
        # Load model
        generator, norm_params = load_for_inference(checkpoint_path, device)
        data_min, data_max = norm_params
        
        # Print normalization parameters for verification
        print("\nNormalization Parameters:")
        for i, key in enumerate(['Enrichment (%)', 'Flux', 'Burnup']):
            print(f"{key:>12}: Min={data_min[i]:.4f}, Max={data_max[i]:.4f}")
        
        # Interactive prediction loop
        while True:
            print("\nEnter conditions (leave blank to predict):")
            conditions = {}
            for key in ['Enrichment (%)', 'Flux', 'Burnup']:
                while True:
                    inp = input(f"{key}: ").strip()
                    if not inp:
                        conditions[key] = None
                        break
                    try:
                        conditions[key] = float(inp)
                        break
                    except ValueError:
                        print("Please enter a number or leave blank")
            
            if all(v is None for v in conditions.values()):
                print("At least one value required!")
                continue
                
            # Generate predictions
            results = generate(generator, norm_params, conditions, device)
            
            # Display results
            print("\nResults:")
            for key, (val, status) in results.items():
                formatted_val = format_small_values(val)
                print(f"{key:>12}: {formatted_val} {status}")
                
    except KeyboardInterrupt:
        print("\nExiting...")
    except Exception as e:
        print(f"Error: {str(e)}")

Loading /home/jovyan/FluxGAN/plots/checkpoint/checkpoint_1000.tar

Normalization Parameters:
Enrichment (%): Min=1.0007, Max=89.9858
        Flux: Min=6.4532, Max=11.9977
      Burnup: Min=0.0000, Max=0.0000

Enter conditions (leave blank to predict):


Enrichment (%):  12
Flux:  
Burnup:  



Results:
Enrichment (%): 12.0000 (input)
        Flux: 11.9977 (predicted)
      Burnup: 1.4858083577e-08 (predicted)

Enter conditions (leave blank to predict):


In [23]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import os
from IPython.display import display

# Corrected Generator Model
class Generator(nn.Module):
    def __init__(self, noise_dim=100):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim + 3, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),  # BatchNorm added here to match checkpoint
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(128),  # BatchNorm added here to match checkpoint
            nn.Linear(128, 3),
            nn.Sigmoid()  # Sigmoid for output
        )

    def forward(self, z, conditions):
        x = torch.cat([z, conditions], dim=1)
        return self.net(x)

# Function to load the checkpoint
def load_checkpoint(checkpoint_path):
    """Load checkpoint with robust error handling"""
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        
        # Debug: show checkpoint contents
        print("Checkpoint keys found:", list(checkpoint.keys()))
        
        # Initialize generator
        generator = Generator(noise_dim=100)
        
        # Try different possible state dict keys
        state_dict_keys = ['generator_state_dict', 'generator', 'model_state_dict', 'state_dict']
        loaded = False
        for key in state_dict_keys:
            if key in checkpoint:
                # Load the state dict with strict=False to ignore missing keys
                generator.load_state_dict(checkpoint[key], strict=False)
                loaded = True
                print(f"Loaded generator weights from key: '{key}'")
                break
        
        if not loaded:
            # Try loading direct state dict
            try:
                generator.load_state_dict(checkpoint, strict=False)
                loaded = True
                print("Loaded generator weights directly from checkpoint")
            except:
                pass
        
        if not loaded:
            raise KeyError("No generator state dict found in checkpoint")
        
        generator.eval()
        
        # Get normalization parameters with defaults
        data_min = checkpoint.get('data_min', [0, 0, 0])
        data_max = checkpoint.get('data_max', [1, 1, 1])
        
        return generator, {
            'data_min': data_min,
            'data_max': data_max,
            'scaler': {'feature_range': (0, 1)}  # Assuming MinMaxScaler
        }
        
    except Exception as e:
        raise ValueError(f"Error loading checkpoint: {str(e)}")

# Function to generate flux and burnup for given enrichment values
def generate_flux_burnup(generator, checkpoint_info, enrichment_values):
    """Generate flux and burnup for given enrichment values"""
    data_min = checkpoint_info['data_min']
    data_max = checkpoint_info['data_max']
    feature_range = checkpoint_info['scaler']['feature_range']
    
    results = []
    
    for enrich in enrichment_values:
        # Normalize enrichment
        if data_max[0] - data_min[0] == 0:  # Prevent division by zero
            norm_enrich = 0.5
        else:
            norm_enrich = ((enrich - data_min[0]) / 
                          (data_max[0] - data_min[0])) * (feature_range[1] - feature_range[0]) + feature_range[0]
        
        # Create input with enrichment specified, others zero
        input_norm = np.array([norm_enrich, 0, 0])
        
        with torch.no_grad():
            z = torch.randn(1, 100)
            conditions_tensor = torch.FloatTensor(input_norm).unsqueeze(0)
            output_norm = generator(z, conditions_tensor).numpy()[0]
        
        # Denormalize outputs
        def denormalize(val, idx):
            return ((val - feature_range[0]) / 
                   (feature_range[1] - feature_range[0])) * (data_max[idx] - data_min[idx]) + data_min[idx]
        
        flux = denormalize(output_norm[1], 1)
        burnup = denormalize(output_norm[2], 2)
        
        results.append([enrich, flux, burnup])
    
    return pd.DataFrame(results, columns=['Enrichment (%)', 'Flux', 'Burnup'])

# Function to find the latest checkpoint in the directory
def find_latest_checkpoint(checkpoint_dir):
    """Find the latest checkpoint with proper error handling"""
    if not os.path.exists(checkpoint_dir):
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")
    
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.tar')]
    if not checkpoints:
        raise FileNotFoundError(f"No .tar checkpoint files found in {checkpoint_dir}")
    
    # Sort by epoch number
    def extract_epoch(f):
        try:
            return int(f.split('_')[-1].split('.')[0])
        except:
            return -1
    
    checkpoints.sort(key=extract_epoch)
    latest = checkpoints[-1]
    return os.path.join(checkpoint_dir, latest)

# Main execution
if __name__ == "__main__":
    try:
        # Configuration
        checkpoint_dir = '/home/jovyan/FluxGAN/plots/checkpoint'
        output_csv = os.path.join(checkpoint_dir, 'generated_flux_burnup.csv')
        
        # Define enrichment values
        enrichment_values = np.array([
            11.49, 3.5, 67.72, 17.36, 11.02, 69.62, 1.73, 42.41, 36.69, 18.87, 
            67.97, 63.78, 55.78, 20.99, 67.82, 60.98, 12.33, 56.87, 62.92, 5.0,
            10.35, 2.11, 64.83, 4.93, 26.02, 76.18, 61.37, 62.78, 35.71, 35.29,
            74.78, 48.16, 36.33, 62.18, 41.57, 32.45, 44.28, 29.37, 59.88, 31.42,
            61.48, 43.76, 60.84, 29.65, 50.77, 14.72, 56.37, 18.69, 59.08, 52.48,
            69.12, 2.92, 42.22, 62.93, 75.02, 14.64, 34.94, 9.09, 67.41, 49.67,
            58.28, 49.66, 45.65, 30.15, 71.5, 52.89, 15.55, 29.94, 47.47, 78.25,
            25.37, 3.81, 34.87, 20.77, 40.62, 63.6, 35.78, 47.05, 23.73, 51.87,
            29.0, 63.33, 66.46, 20.4, 14.71, 11.55, 57.87, 55.93, 7.57, 61.82,
            62.88, 70.7, 36.32, 75.0, 39.93, 54.81, 74.23, 24.26, 23.35, 79.74
        ])
        
        # Find and load checkpoint
        checkpoint_path = find_latest_checkpoint(checkpoint_dir)
        print(f"\nLoading checkpoint from: {checkpoint_path}")
        
        generator, checkpoint_info = load_checkpoint(checkpoint_path)
        print("\nGenerator successfully loaded!")
        print(f"Data min: {checkpoint_info['data_min']}")
        print(f"Data max: {checkpoint_info['data_max']}")
        
        # Generate predictions
        print("\nGenerating flux and burnup values...")
        results_df = generate_flux_burnup(generator, checkpoint_info, enrichment_values)
        
        # Save and display results
        results_df.to_csv(output_csv, index=False)
        print(f"\nResults saved to: {output_csv}")
        print("\nFirst 5 predictions:")
        display(results_df.head())
        
    except Exception as e:
        print(f"\nError occurred: {str(e)}")
        print("\nTroubleshooting steps:")
        print("1. Verify the checkpoint directory exists and contains .tar files")
        print("2. Check the checkpoint contents with:")
        print(f"   import torch; print(torch.load('{checkpoint_path}', map_location='cpu').keys())")
        print("3. Ensure the Generator architecture matches your training code")



Loading checkpoint from: /home/jovyan/FluxGAN/plots/checkpoint/checkpoint_10.tar
Checkpoint keys found: ['epoch', 'generator_state_dict', 'discriminator_state_dict', 'optimizer_G_state_dict', 'optimizer_D_state_dict', 'data_min', 'data_max']
Loaded generator weights from key: 'generator_state_dict'

Generator successfully loaded!
Data min: [1.00073484e+00 6.45316091e+00 1.48580836e-08]
Data max: [8.99858114e+01 1.19977241e+01 4.58969245e-08]

Generating flux and burnup values...

Results saved to: /home/jovyan/FluxGAN/plots/checkpoint/generated_flux_burnup.csv

First 5 predictions:


Unnamed: 0,Enrichment (%),Flux,Burnup
0,11.49,9.186839,3.255059e-08
1,3.5,10.010794,2.756544e-08
2,67.72,8.791043,2.280178e-08
3,17.36,9.370809,2.538571e-08
4,11.02,9.563942,3.368628e-08
