In [1]:
# First install CUDA-compatible dependencies
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html

# Then install torch-geometric
!pip install torch_geometric 

# Verify installation
import torch_geometric
print(f"Success! PyG version: {torch_geometric.__version__}")

Looking in links: https://data.pyg.org/whl/torch-2.1.0+cu121.html
Collecting pyg_lib
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu121/pyg_lib-0.4.0%2Bpt21cu121-cp310-cp310-linux_x86_64.whl (2.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m27.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_scatter
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu121/torch_scatter-2.1.2%2Bpt21cu121-cp310-cp310-linux_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_sparse
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu121/torch_sparse-0.6.18%2Bpt21cu121-cp310-cp310-linux_x86_64.whl (5.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m55.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_cluster
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu121/torch_cluste



Success! PyG version: 2.6.1


In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch_geometric.data import Data
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data, DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm
import random

# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True

set_seed(42)

In [3]:
MAX_SEQ_LEN = 1024

In [4]:
# Define paths to data files
TRAIN_SEQUENCES_PATH = "/kaggle/input/stanford-rna-3d-folding/train_sequences.csv"
TRAIN_LABELS_PATH = "/kaggle/input/stanford-rna-3d-folding/train_labels.csv"
# Load data
train_sequences = pd.read_csv(TRAIN_SEQUENCES_PATH)
train_labels = pd.read_csv(TRAIN_LABELS_PATH)

print(f"Loaded {len(train_sequences)} RNA sequences and {len(train_labels)} nucleotide labels")

Loaded 844 RNA sequences and 137095 nucleotide labels


In [5]:
# Preprocess data
# 1. Encoding nucleotides
nucleotide_mapping = {'A': 0, 'C': 1, 'G': 2, 'U': 3}
reverse_mapping = {0: 'A', 1: 'C', 2: 'G', 3: 'U'}

In [6]:
# 2. Create feature representation for each nucleotide
def one_hot_encode(nucleotide):
    encoding = [0, 0, 0, 0]
    if nucleotide in nucleotide_mapping:
        encoding[nucleotide_mapping[nucleotide]] = 1
    return encoding

In [7]:
# Function to create a graph from an RNA sequence
def sequence_to_graph(sequence, target_id, labels_df=None, max_connections=MAX_SEQ_LEN):
    """
    Create a graph representation of an RNA sequence.
    
    Args:
        sequence: The RNA sequence
        target_id: Identifier for the RNA
        labels_df: Optional dataframe with 3D coordinate labels
        max_connections: Maximum number of edges to create (to avoid CUDA OOM errors)
        
    Returns:
        PyTorch Geometric Data object
    """
    # One-hot encode each nucleotide
    x = [one_hot_encode(nt) for nt in sequence]
    x = torch.tensor(x, dtype=torch.float)
    
    # Create edges - connect adjacent nucleotides (backbone)
    # and potentially other connections based on domain knowledge
    edges = []
    
    # Always add backbone connections
    for i in range(len(sequence) - 1):
        # Connect to next nucleotide (backbone)
        edges.append([i, i + 1])
        edges.append([i + 1, i])  # Bidirectional
    
    # Add potential base-pairing connections, but limit total edges to avoid OOM
    edge_count = len(edges)
    max_additional_edges = max_connections - edge_count
    
    if max_additional_edges > 0:
        potential_base_pairs = []
        
        # Identify potential base pairs (A-U, G-C)
        for i in range(len(sequence)):
            for j in range(i + 3, len(sequence)):  # Minimum loop size of 3
                if (sequence[i] == 'A' and sequence[j] == 'U') or \
                   (sequence[i] == 'U' and sequence[j] == 'A') or \
                   (sequence[i] == 'G' and sequence[j] == 'C') or \
                   (sequence[i] == 'C' and sequence[j] == 'G'):
                    # Store the potential base pair
                    potential_base_pairs.append((i, j))
        
        # Randomly select base pairs if we have too many
        if len(potential_base_pairs) > max_additional_edges // 2:  # Divide by 2 for bidirectional edges
            # Shuffle and take only what we can handle
            random.shuffle(potential_base_pairs)
            potential_base_pairs = potential_base_pairs[:max_additional_edges // 2]
        
        # Add the selected base pairs
        for i, j in potential_base_pairs:
            edges.append([i, j])
            edges.append([j, i])  # Bidirectional
    
    # Convert edges to tensor
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    # Get coordinates if available
    y = None
    mask = None
    if labels_df is not None:
        target_labels = labels_df[labels_df['ID'].str.startswith(target_id + '_')]
        
        # Sort by residue ID to match sequence order
        target_labels = target_labels.sort_values(by='resid')
        
        # Check if we have the expected number of residues
        if len(target_labels) == len(sequence):
            # Extract coordinates for each residue
            coordinates = target_labels[['x_1', 'y_1', 'z_1']].values
            
            # Create a mask for NaN values (1 for valid, 0 for NaN)
            valid_mask = ~np.isnan(coordinates).any(axis=1)
            mask = torch.tensor(valid_mask, dtype=torch.float)
            
            # Replace NaN with zeros (we'll mask these during loss calculation)
            coordinates = np.nan_to_num(coordinates, nan=0.0)
            
            y = torch.tensor(coordinates, dtype=torch.float)
        else:
            print(f"Warning: Mismatch in sequence length and label count for {target_id}")
    
    # Create the data object with properly typed target_id (as string)
    data = Data(x=x, edge_index=edge_index, y=y, mask=mask)
    
    # Store target_id as a string attribute
    data.target_id = str(target_id)
    
    return data

In [8]:
def create_dataset(sequences_df, labels_df=None):
    dataset = []
    skipped_count = 0
    nan_count = 0
    
    for idx, row in tqdm(sequences_df.iterrows(), total=len(sequences_df)):
        target_id = row['target_id']
        sequence = row['sequence']
        
        # Clean sequence - replace any non-standard nucleotides with 'N'
        # and count how many non-standard nucleotides there are
        cleaned_sequence = ''
        non_standard_count = 0
        
        for nt in sequence:
            if nt in nucleotide_mapping:
                cleaned_sequence += nt
            else:
                cleaned_sequence += 'N'  # Placeholder for non-standard nucleotides
                non_standard_count += 1
        
        # If too many non-standard nucleotides (>10%), skip this sequence
        if non_standard_count / len(sequence) > 0.1:
            print(f"Skipping sequence {target_id} with {non_standard_count} non-standard nucleotides")
            skipped_count += 1
            continue
        
        # Create graph
        graph = sequence_to_graph(cleaned_sequence, target_id, labels_df)
        
        # Check if we have labels with many NaN values
        if labels_df is not None and hasattr(graph, 'mask') and graph.mask is not None:
            nan_percentage = 1.0 - torch.mean(graph.mask).item()
            if nan_percentage > 0.5:  # If more than 50% coordinates are NaN
                print(f"Warning: Sequence {target_id} has {nan_percentage:.1%} NaN coordinates")
                nan_count += 1
        
        # Add to dataset if no labels needed or valid labels exist
        if labels_df is None or graph.y is not None:
            dataset.append(graph)
    
    print(f"Dataset creation: {skipped_count} sequences skipped due to non-standard nucleotides")
    print(f"Dataset creation: {nan_count} sequences have >50% NaN coordinates")
    
    return dataset

In [9]:

data={
      "sequence":train_sequences['sequence'].to_list(),
      "temporal_cutoff": train_sequences['temporal_cutoff'].to_list(),
      "description": train_sequences['description'].to_list(),
      "all_sequences": train_sequences['all_sequences'].to_list(),
}
config = {
    "cutoff_date": "2020-01-01",
    "test_cutoff_date": "2022-05-01",
}

In [10]:
# Split data into train and test
all_index = np.arange(len(data['sequence']))
cutoff_date = pd.Timestamp(config['cutoff_date'])
test_cutoff_date = pd.Timestamp(config['test_cutoff_date'])
train_index = [i for i, d in enumerate(data['temporal_cutoff']) if pd.Timestamp(d) <= cutoff_date]
test_index = [i for i, d in enumerate(data['temporal_cutoff']) if pd.Timestamp(d) > cutoff_date and pd.Timestamp(d) <= test_cutoff_date]

In [11]:
# Create training dataset
train_dataset = create_dataset(train_sequences, train_labels)
print(f"Created {len(train_dataset)} graph data objects for training")

 13%|█▎        | 113/844 [00:04<00:25, 29.01it/s]



 16%|█▌        | 134/844 [00:04<00:24, 28.82it/s]



 17%|█▋        | 140/844 [00:04<00:24, 28.59it/s]



 17%|█▋        | 147/844 [00:05<00:23, 29.15it/s]



 20%|██        | 169/844 [00:05<00:23, 29.26it/s]



 21%|██        | 178/844 [00:06<00:22, 29.05it/s]



 22%|██▏       | 184/844 [00:06<00:22, 28.90it/s]



 25%|██▍       | 208/844 [00:07<00:22, 28.23it/s]



 26%|██▌       | 217/844 [00:07<00:27, 23.13it/s]



 28%|██▊       | 235/844 [00:08<00:22, 26.64it/s]



 29%|██▉       | 244/844 [00:08<00:27, 21.67it/s]



 34%|███▎      | 283/844 [00:10<00:19, 28.70it/s]



 36%|███▌      | 304/844 [00:11<00:20, 26.69it/s]



 37%|███▋      | 313/844 [00:11<00:20, 26.32it/s]



 41%|████      | 346/844 [00:12<00:19, 25.93it/s]



 51%|█████     | 427/844 [00:15<00:15, 27.10it/s]



 52%|█████▏    | 442/844 [00:16<00:15, 26.19it/s]



 53%|█████▎    | 451/844 [00:16<00:14, 26.24it/s]



 54%|█████▍    | 457/844 [00:16<00:14, 26.15it/s]



 55%|█████▍    | 463/844 [00:16<00:14, 26.65it/s]



 60%|██████    | 508/844 [00:18<00:12, 27.23it/s]



 69%|██████▊   | 580/844 [00:21<00:10, 26.21it/s]



 70%|███████   | 595/844 [00:21<00:09, 26.27it/s]



 73%|███████▎  | 616/844 [00:22<00:08, 26.50it/s]



 76%|███████▌  | 643/844 [00:23<00:07, 26.37it/s]



 80%|███████▉  | 673/844 [00:24<00:06, 26.64it/s]



 80%|████████  | 679/844 [00:25<00:06, 26.88it/s]



 82%|████████▏ | 691/844 [00:25<00:05, 27.61it/s]



 94%|█████████▎| 790/844 [00:29<00:02, 25.32it/s]



 95%|█████████▌| 805/844 [00:30<00:01, 24.39it/s]



 96%|█████████▌| 811/844 [00:30<00:01, 25.40it/s]



100%|██████████| 844/844 [00:31<00:00, 26.70it/s]

Dataset creation: 0 sequences skipped due to non-standard nucleotides
Dataset creation: 69 sequences have >50% NaN coordinates
Created 844 graph data objects for training





In [12]:
train_graphs = train_dataset[:len(train_index)]
val_graphs = train_dataset[:len(train_index)]

In [13]:
# Define the GNN model
class RNAStructurePredictor(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, output_dim=3, num_layers=10, max_seq_len=MAX_SEQ_LEN):
        super(RNAStructurePredictor, self).__init__()
        
        # Initial embedding layer
        self.embedding = nn.Linear(input_dim, hidden_dim)
        
        # SAGEConv layers
        self.conv_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.conv_layers.append(SAGEConv(hidden_dim, hidden_dim))
        
        # Output layer for 3D coordinates prediction (x, y, z)
        self.output = nn.Linear(hidden_dim, output_dim)
        
        # Add attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
        # Position encoding - increase max sequence length
        self.position_encoder = nn.Embedding(max_seq_len, hidden_dim)
        
        # Initialize parameters with Xavier/Glorot
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        # Initial embedding
        x = self.embedding(x)
        
        # Add positional information with bounds checking
        max_pos = self.position_encoder.weight.size(0) - 1  # Maximum allowed index
        pos = torch.arange(x.size(0), device=x.device)
        # Clamp position indices to avoid out-of-bounds errors
        pos = torch.clamp(pos, max=max_pos)
        x = x + self.position_encoder(pos)
        
        # Graph convolution layers
        for conv in self.conv_layers:
            x_residual = x
            x = conv(x, edge_index)
            x = F.relu(x)
            x = x + x_residual  # Skip connection
            x = F.dropout(x, p=0.2, training=self.training)
        
        # Apply attention
        attention_weights = self.attention(x)
        x = x * attention_weights
        
        # Predict 3D coordinates
        coordinates = self.output(x)
        
        return coordinates

In [14]:
# Define loss function for 3D coordinate prediction
def rmsd_loss(pred, target, mask=None):
    """
    Root Mean Square Deviation (RMSD) loss function with optional masking for NaN values.
    Lower RMSD indicates better structural similarity.
    
    Args:
        pred: Predicted coordinates, shape (n_nucleotides, 3)
        target: Target coordinates, shape (n_nucleotides, 3)
        mask: Optional mask for valid values, shape (n_nucleotides,)
    """
    squared_diff = torch.sum((pred - target) ** 2, dim=1)
    
    if mask is not None:
        # Apply mask to consider only valid coordinates
        # Ensure we don't divide by zero by adding a small epsilon to the sum
        masked_squared_diff = squared_diff * mask
        mean_squared_diff = torch.sum(masked_squared_diff) / (torch.sum(mask) + 1e-10)
    else:
        mean_squared_diff = torch.mean(squared_diff)
    
    rmsd = torch.sqrt(mean_squared_diff)
    return rmsd

In [15]:
def calculate_distance_matrix(X,Y,epsilon=1e-4):
    return (torch.square(X[:,None]-Y[None,:])+epsilon).sum(-1).sqrt()


def dRMSD(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=None):
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)



    mask=~torch.isnan(gt_dm)
    mask[torch.eye(mask.shape[0]).bool()]=False

    if d_clamp is not None:
        rmsd=(torch.square(pred_dm[mask]-gt_dm[mask])+epsilon).clip(0,d_clamp**2)
    else:
        rmsd=torch.square(pred_dm[mask]-gt_dm[mask])+epsilon

    return rmsd.sqrt().mean()/Z

def local_dRMSD(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=30):
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)



    mask=(~torch.isnan(gt_dm))*(gt_dm<d_clamp)
    mask[torch.eye(mask.shape[0]).bool()]=False



    rmsd=torch.square(pred_dm[mask]-gt_dm[mask])+epsilon
    # rmsd=(torch.square(pred_dm[mask]-gt_dm[mask])+epsilon).sqrt()/Z
    #rmsd=torch.abs(pred_dm[mask]-gt_dm[mask])/Z
    return rmsd.sqrt().mean()/Z

def dRMAE(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=None):
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)



    mask=~torch.isnan(gt_dm)
    mask[torch.eye(mask.shape[0]).bool()]=False

    rmsd=torch.abs(pred_dm[mask]-gt_dm[mask])

    return rmsd.mean()/Z

import torch

def align_svd_mae(input, target, Z=10):
    """
    Aligns the input (Nx3) to target (Nx3) using SVD-based Procrustes alignment
    and computes RMSD loss.
    
    Args:
        input (torch.Tensor): Nx3 tensor representing the input points.
        target (torch.Tensor): Nx3 tensor representing the target points.
    
    Returns:
        aligned_input (torch.Tensor): Nx3 aligned input.
        rmsd_loss (torch.Tensor): RMSD loss.
    """
    assert input.shape == target.shape, "Input and target must have the same shape"

    #mask 
    mask=~torch.isnan(target.sum(-1))

    input=input[mask]
    target=target[mask]
    
    # Compute centroids
    centroid_input = input.mean(dim=0, keepdim=True)
    centroid_target = target.mean(dim=0, keepdim=True)

    # Center the points
    input_centered = input - centroid_input.detach()
    target_centered = target - centroid_target

    # Compute covariance matrix
    cov_matrix = input_centered.T @ target_centered

    # SVD to find optimal rotation
    U, S, Vt = torch.svd(cov_matrix)

    # Compute rotation matrix
    R = Vt @ U.T

    # Ensure a proper rotation (det(R) = 1, no reflection)
    if torch.det(R) < 0:
        Vt[-1, :] *= -1
        R = Vt @ U.T

    # Rotate input
    aligned_input = (input_centered @ R.T.detach()) + centroid_target.detach()

    # # Compute RMSD loss
    # rmsd_loss = torch.sqrt(((aligned_input - target) ** 2).mean())

    # rmsd_loss = torch.sqrt(((aligned_input - target) ** 2).mean())
    
    # return aligned_input, rmsd_loss
    return torch.abs(aligned_input-target).mean()/Z

In [16]:
def train(model, train_loader, optimizer, device):
    model.train()
    total_loss = 0
    loss_values = []
    # Create tqdm progress bar with loss display
    pbar = tqdm(train_loader, desc='Training')
    
    for data in pbar:
        data = data.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        pred = model(data)
        
        # Calculate loss if labels exist
        if data.y is not None:
            # Use mask if available
            if hasattr(data, 'mask') and data.mask is not None:
                loss = dRMAE(pred,pred,data.y,data.y) + align_svd_mae(pred, data.y)
                # loss = rmsd_loss(pred, data.y, data.mask)
            else:
                loss = dRMAE(pred,pred,data.y,data.y) + align_svd_mae(pred, data.y)
                # loss = rmsd_loss(pred, data.y)
                
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            loss_values.append(loss.item())
            
            # Update progress bar with current loss
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'smooth loss': np.mean(loss_values[-100:])})
    
    avg_loss = total_loss / len(train_loader)
    return avg_loss

def validate(model, val_loader, device):
    model.eval()
    total_loss = 0
    
    # Create tqdm progress bar with loss display
    pbar = tqdm(val_loader, desc='Validation')
    
    with torch.no_grad():
        for data in pbar:
            data = data.to(device)
            pred = model(data)
            
            if data.y is not None:
                # Use mask if available
                if hasattr(data, 'mask') and data.mask is not None:
                    loss = dRMAE(pred,pred,data.y,data.y) + align_svd_mae(pred, data.y)
                    # loss = rmsd_loss(pred, data.y, data.mask)
                else:
                    loss = dRMAE(pred,pred,data.y,data.y) + align_svd_mae(pred, data.y)
                    # loss = rmsd_loss(pred, data.y)
                total_loss += loss.item()
                
                # Update progress bar with current loss
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    avg_loss = total_loss / len(val_loader)
    return avg_loss

In [17]:
# Function to make predictions on test data
def predict(model, test_loader, device):
    model.eval()
    predictions = {}
    
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            pred = model(data)
            
            # Store predictions
            target_id = data.target_id
            
            # If we have ground truth and mask, report metrics
            if hasattr(data, 'y') and data.y is not None:
                if hasattr(data, 'mask') and data.mask is not None:
                    loss = rmsd_loss(pred, data.y, data.mask).item()
                else:
                    loss = rmsd_loss(pred, data.y).item()
                print(f"Prediction for {target_id}, RMSD: {loss:.4f}")
            
            predictions[target_id] = pred.cpu().numpy()
    
    return predictions

In [18]:
# Setup for training
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
device = torch.device(device)
print(f"Using device: {device}")

Using device: cuda


In [19]:
# Create data loaders
train_loader = DataLoader(train_graphs, batch_size=8, shuffle=True, num_workers=4)
val_loader = DataLoader(val_graphs, batch_size=8, shuffle=False, num_workers=4)



In [20]:
# Initialize model
input_dim = 4  # One-hot encoding dimension for nucleotides
model = RNAStructurePredictor(input_dim, hidden_dim=1024, output_dim=3, num_layers=15, max_seq_len=10000).to(device)
print(f"Model initialized with max sequence length of 10000")

Model initialized with max sequence length of 10000


In [21]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.00003)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=12)

In [22]:
# Training loop
num_epochs = 2000
best_val_loss = float('inf')
early_stopping_patience = 150
early_stopping_counter = 0

train_losses = []
val_losses = []

print("Starting training...")
for epoch in range(num_epochs):
    # Train
    train_loss = train(model, train_loader, optimizer, device)
    train_losses.append(train_loss)
    
    # Validate
    val_loss = validate(model, val_loader, device)
    val_losses.append(val_loss)
    
    # Learning rate scheduler
    scheduler.step(val_loss)
    
    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        # Save best model
        torch.save(model.state_dict(), "best_rna_structure_model.pt")
    else:
        early_stopping_counter += 1
    
    print(f"Epoch {epoch+1}/{num_epochs}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Val Loss: {val_loss:.4f}, "
          f"LR: {optimizer.param_groups[0]['lr']:.6f}")
    
    if early_stopping_counter >= early_stopping_patience:
        print(f"Early stopping triggered after {epoch+1} epochs")
        break

print("Training completed!")

Starting training...


Training: 100%|██████████| 76/76 [00:09<00:00,  8.36it/s, loss=7.1510, smooth loss=18.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.73it/s, loss=21.8392]


Epoch 1/2000, Train Loss: 18.5489, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:06<00:00, 11.10it/s, loss=2.9815, smooth loss=17.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.85it/s, loss=21.8392]


Epoch 2/2000, Train Loss: 17.6925, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:07<00:00, 10.49it/s, loss=5.1950, smooth loss=17.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.68it/s, loss=21.8392]


Epoch 3/2000, Train Loss: 17.8774, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:06<00:00, 11.30it/s, loss=3.8799, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.43it/s, loss=21.8392]


Epoch 4/2000, Train Loss: 18.1151, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:06<00:00, 11.43it/s, loss=2.8164, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.37it/s, loss=21.8392]


Epoch 5/2000, Train Loss: 18.5997, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:06<00:00, 11.42it/s, loss=9.4302, smooth loss=17.9]
Validation: 100%|██████████| 76/76 [00:03<00:00, 19.17it/s, loss=21.8392]


Epoch 6/2000, Train Loss: 17.9282, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:06<00:00, 11.76it/s, loss=11.2863, smooth loss=17.8]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.87it/s, loss=21.8392]


Epoch 7/2000, Train Loss: 17.8089, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:06<00:00, 11.58it/s, loss=2.7140, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.72it/s, loss=21.8392]


Epoch 8/2000, Train Loss: 18.1212, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:06<00:00, 11.65it/s, loss=3.2506, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.78it/s, loss=21.8392]


Epoch 9/2000, Train Loss: 18.1055, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:06<00:00, 11.18it/s, loss=3.2058, smooth loss=17.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.78it/s, loss=21.8392]


Epoch 10/2000, Train Loss: 17.9360, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:06<00:00, 11.36it/s, loss=8.1641, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.88it/s, loss=21.8392]


Epoch 11/2000, Train Loss: 18.2748, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:06<00:00, 11.46it/s, loss=5.6264, smooth loss=19.5]
Validation: 100%|██████████| 76/76 [00:03<00:00, 19.10it/s, loss=21.8392]


Epoch 12/2000, Train Loss: 19.5480, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:06<00:00, 11.39it/s, loss=6.4467, smooth loss=18]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.81it/s, loss=21.8392]


Epoch 13/2000, Train Loss: 17.9748, Val Loss: 13.3715, LR: 0.000030


Training: 100%|██████████| 76/76 [00:06<00:00, 11.70it/s, loss=25.5335, smooth loss=18.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.38it/s, loss=21.8392]


Epoch 14/2000, Train Loss: 18.4801, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.51it/s, loss=21.0850, smooth loss=19]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.54it/s, loss=21.8392]


Epoch 15/2000, Train Loss: 18.9708, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.02it/s, loss=2.0128, smooth loss=17.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.22it/s, loss=21.8392]


Epoch 16/2000, Train Loss: 17.7144, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.46it/s, loss=18.0775, smooth loss=19]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.29it/s, loss=21.8392]


Epoch 17/2000, Train Loss: 18.9832, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.27it/s, loss=0.8955, smooth loss=19.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.24it/s, loss=21.8392]


Epoch 18/2000, Train Loss: 19.1089, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.33it/s, loss=7.6478, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.57it/s, loss=21.8392]


Epoch 19/2000, Train Loss: 18.2845, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.32it/s, loss=2.9936, smooth loss=18]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.44it/s, loss=21.8392]


Epoch 20/2000, Train Loss: 17.9747, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.51it/s, loss=16.1580, smooth loss=19.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.77it/s, loss=21.8392]


Epoch 21/2000, Train Loss: 19.2503, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.55it/s, loss=12.1986, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.82it/s, loss=21.8392]


Epoch 22/2000, Train Loss: 18.5548, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.64it/s, loss=8.8808, smooth loss=17]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.58it/s, loss=21.8392]


Epoch 23/2000, Train Loss: 17.0259, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.22it/s, loss=7.8038, smooth loss=19.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.56it/s, loss=21.8392]


Epoch 24/2000, Train Loss: 19.3053, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.41it/s, loss=3.7466, smooth loss=19.2]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.54it/s, loss=21.8392]


Epoch 25/2000, Train Loss: 19.1864, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.53it/s, loss=16.3719, smooth loss=18.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.67it/s, loss=21.8392]


Epoch 26/2000, Train Loss: 18.6942, Val Loss: 13.3715, LR: 0.000015


Training: 100%|██████████| 76/76 [00:06<00:00, 11.20it/s, loss=3.1368, smooth loss=18.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.62it/s, loss=21.8392]


Epoch 27/2000, Train Loss: 18.5485, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 11.44it/s, loss=1.3284, smooth loss=17.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.36it/s, loss=21.8392]


Epoch 28/2000, Train Loss: 17.5395, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 11.38it/s, loss=1.1960, smooth loss=19.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.75it/s, loss=21.8392]


Epoch 29/2000, Train Loss: 19.8831, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 11.54it/s, loss=19.2521, smooth loss=18.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.46it/s, loss=21.8392]


Epoch 30/2000, Train Loss: 18.4676, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 10.95it/s, loss=6.6708, smooth loss=18.8]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.92it/s, loss=21.8392]


Epoch 31/2000, Train Loss: 18.7743, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 11.54it/s, loss=34.9586, smooth loss=17.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.65it/s, loss=21.8392]


Epoch 32/2000, Train Loss: 17.8974, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 11.52it/s, loss=7.7592, smooth loss=18.4]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.15it/s, loss=21.8392]


Epoch 33/2000, Train Loss: 18.3678, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 11.11it/s, loss=6.6550, smooth loss=19.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.28it/s, loss=21.8392]


Epoch 34/2000, Train Loss: 19.3318, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 11.68it/s, loss=21.9965, smooth loss=17.4]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.68it/s, loss=21.8392]


Epoch 35/2000, Train Loss: 17.3515, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 11.74it/s, loss=1.7498, smooth loss=17.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.88it/s, loss=21.8392]


Epoch 36/2000, Train Loss: 17.5885, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 11.25it/s, loss=26.4492, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.30it/s, loss=21.8392]


Epoch 37/2000, Train Loss: 18.3094, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 11.69it/s, loss=17.6762, smooth loss=17.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.75it/s, loss=21.8392]


Epoch 38/2000, Train Loss: 17.6109, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 11.46it/s, loss=5.7311, smooth loss=19.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.76it/s, loss=21.8392]


Epoch 39/2000, Train Loss: 19.7419, Val Loss: 13.3715, LR: 0.000008


Training: 100%|██████████| 76/76 [00:06<00:00, 11.35it/s, loss=5.3899, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.50it/s, loss=21.8392]


Epoch 40/2000, Train Loss: 18.0563, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 11.71it/s, loss=12.8253, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 19.00it/s, loss=21.8392]


Epoch 41/2000, Train Loss: 18.6213, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 11.38it/s, loss=13.2026, smooth loss=18]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.37it/s, loss=21.8392]


Epoch 42/2000, Train Loss: 18.0490, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 11.68it/s, loss=11.6778, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.71it/s, loss=21.8392]


Epoch 43/2000, Train Loss: 18.3275, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 11.56it/s, loss=3.2803, smooth loss=17.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.37it/s, loss=21.8392]


Epoch 44/2000, Train Loss: 17.5942, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 11.57it/s, loss=4.3228, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.75it/s, loss=21.8392]


Epoch 45/2000, Train Loss: 18.2693, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 11.43it/s, loss=3.7723, smooth loss=18.8]
Validation: 100%|██████████| 76/76 [00:03<00:00, 19.01it/s, loss=21.8392]


Epoch 46/2000, Train Loss: 18.7982, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 11.69it/s, loss=7.8030, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:03<00:00, 19.04it/s, loss=21.8392]


Epoch 47/2000, Train Loss: 18.5994, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 11.54it/s, loss=6.9685, smooth loss=17.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 15.97it/s, loss=21.8392]


Epoch 48/2000, Train Loss: 17.8760, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 11.11it/s, loss=25.6010, smooth loss=18.4]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.73it/s, loss=21.8392]


Epoch 49/2000, Train Loss: 18.4381, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 11.41it/s, loss=12.6014, smooth loss=17.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.17it/s, loss=21.8392]


Epoch 50/2000, Train Loss: 17.6116, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 11.24it/s, loss=18.6284, smooth loss=18.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.46it/s, loss=21.8392]


Epoch 51/2000, Train Loss: 18.5067, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 11.57it/s, loss=5.9350, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.24it/s, loss=21.8392]


Epoch 52/2000, Train Loss: 18.1100, Val Loss: 13.3715, LR: 0.000004


Training: 100%|██████████| 76/76 [00:06<00:00, 10.93it/s, loss=10.6145, smooth loss=19.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.39it/s, loss=21.8392]


Epoch 53/2000, Train Loss: 19.6334, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 11.44it/s, loss=5.9034, smooth loss=19.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.68it/s, loss=21.8392]


Epoch 54/2000, Train Loss: 19.1438, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 11.53it/s, loss=2.3544, smooth loss=18.2]
Validation: 100%|██████████| 76/76 [00:03<00:00, 19.03it/s, loss=21.8392]


Epoch 55/2000, Train Loss: 18.1842, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 11.56it/s, loss=7.2104, smooth loss=18.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.82it/s, loss=21.8392]


Epoch 56/2000, Train Loss: 18.9043, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 11.22it/s, loss=1.6479, smooth loss=18.7]
Validation: 100%|██████████| 76/76 [00:03<00:00, 19.03it/s, loss=21.8392]


Epoch 57/2000, Train Loss: 18.6856, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 11.55it/s, loss=14.3530, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:03<00:00, 19.10it/s, loss=21.8392]


Epoch 58/2000, Train Loss: 18.1449, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 11.67it/s, loss=4.3608, smooth loss=18.8]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.82it/s, loss=21.8392]


Epoch 59/2000, Train Loss: 18.8439, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 11.61it/s, loss=2.3710, smooth loss=17.8]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.59it/s, loss=21.8392]


Epoch 60/2000, Train Loss: 17.7807, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 11.34it/s, loss=9.5153, smooth loss=18.4]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.78it/s, loss=21.8392]


Epoch 61/2000, Train Loss: 18.3616, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 11.43it/s, loss=3.0160, smooth loss=17.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.84it/s, loss=21.8392]


Epoch 62/2000, Train Loss: 17.8689, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 10.99it/s, loss=3.0122, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:03<00:00, 19.28it/s, loss=21.8392]


Epoch 63/2000, Train Loss: 18.2574, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 11.51it/s, loss=9.2709, smooth loss=17.9]
Validation: 100%|██████████| 76/76 [00:03<00:00, 19.00it/s, loss=21.8392]


Epoch 64/2000, Train Loss: 17.8544, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 11.30it/s, loss=11.8520, smooth loss=19.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.63it/s, loss=21.8392]


Epoch 65/2000, Train Loss: 19.4808, Val Loss: 13.3715, LR: 0.000002


Training: 100%|██████████| 76/76 [00:06<00:00, 11.02it/s, loss=13.7626, smooth loss=19]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.16it/s, loss=21.8392]


Epoch 66/2000, Train Loss: 19.0321, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:06<00:00, 11.36it/s, loss=5.7589, smooth loss=17.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.27it/s, loss=21.8392]


Epoch 67/2000, Train Loss: 17.6441, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:06<00:00, 10.98it/s, loss=12.3444, smooth loss=17.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.96it/s, loss=21.8392]


Epoch 68/2000, Train Loss: 17.9347, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:06<00:00, 10.98it/s, loss=14.6388, smooth loss=17.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.67it/s, loss=21.8392]


Epoch 69/2000, Train Loss: 17.3392, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:06<00:00, 11.44it/s, loss=4.6071, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.38it/s, loss=21.8392]


Epoch 70/2000, Train Loss: 18.2789, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:06<00:00, 11.57it/s, loss=2.6625, smooth loss=18.8]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.11it/s, loss=21.8392]


Epoch 71/2000, Train Loss: 18.7524, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:06<00:00, 11.38it/s, loss=8.6630, smooth loss=18.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.41it/s, loss=21.8392]


Epoch 72/2000, Train Loss: 18.4749, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:06<00:00, 11.40it/s, loss=10.3777, smooth loss=18.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.05it/s, loss=21.8392]


Epoch 73/2000, Train Loss: 18.6620, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:07<00:00, 10.74it/s, loss=4.8827, smooth loss=18.8]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.50it/s, loss=21.8392]


Epoch 74/2000, Train Loss: 18.8183, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:06<00:00, 11.32it/s, loss=12.0415, smooth loss=17.2]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.75it/s, loss=21.8392]


Epoch 75/2000, Train Loss: 17.1713, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:06<00:00, 11.21it/s, loss=18.7303, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.40it/s, loss=21.8392]


Epoch 76/2000, Train Loss: 18.0547, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:06<00:00, 11.62it/s, loss=4.6456, smooth loss=17.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.97it/s, loss=21.8392]


Epoch 77/2000, Train Loss: 17.6667, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:06<00:00, 11.27it/s, loss=28.9344, smooth loss=17.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.21it/s, loss=21.8392]


Epoch 78/2000, Train Loss: 17.6506, Val Loss: 13.3715, LR: 0.000001


Training: 100%|██████████| 76/76 [00:06<00:00, 10.99it/s, loss=5.3188, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.49it/s, loss=21.8392]


Epoch 79/2000, Train Loss: 18.5559, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.08it/s, loss=15.7492, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.11it/s, loss=21.8392]


Epoch 80/2000, Train Loss: 18.6120, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.54it/s, loss=7.4579, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.91it/s, loss=21.8392]


Epoch 81/2000, Train Loss: 18.1390, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.75it/s, loss=6.0599, smooth loss=18.2]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.80it/s, loss=21.8392]


Epoch 82/2000, Train Loss: 18.1655, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.32it/s, loss=7.7461, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.86it/s, loss=21.8392]


Epoch 83/2000, Train Loss: 18.1456, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.51it/s, loss=11.5000, smooth loss=18.4]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.45it/s, loss=21.8392]


Epoch 84/2000, Train Loss: 18.4396, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.54it/s, loss=5.2379, smooth loss=18.4]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.59it/s, loss=21.8392]


Epoch 85/2000, Train Loss: 18.3955, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:07<00:00, 10.65it/s, loss=2.2321, smooth loss=17.8]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.46it/s, loss=21.8392]


Epoch 86/2000, Train Loss: 17.8016, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.29it/s, loss=5.8281, smooth loss=19]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.08it/s, loss=21.8392]


Epoch 87/2000, Train Loss: 18.9588, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.27it/s, loss=6.3182, smooth loss=19.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.29it/s, loss=21.8392]


Epoch 88/2000, Train Loss: 19.2568, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.21it/s, loss=12.2130, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.90it/s, loss=21.8392]


Epoch 89/2000, Train Loss: 18.2652, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.53it/s, loss=4.6643, smooth loss=17.8]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.23it/s, loss=21.8392]


Epoch 90/2000, Train Loss: 17.7605, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.62it/s, loss=3.3974, smooth loss=17.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.40it/s, loss=21.8392]


Epoch 91/2000, Train Loss: 17.1175, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.26it/s, loss=2.9300, smooth loss=18.2]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.59it/s, loss=21.8392]


Epoch 92/2000, Train Loss: 18.2278, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.46it/s, loss=2.4008, smooth loss=19.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.60it/s, loss=21.8392]


Epoch 93/2000, Train Loss: 19.1387, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.54it/s, loss=10.5175, smooth loss=17.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.04it/s, loss=21.8392]


Epoch 94/2000, Train Loss: 17.3012, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.25it/s, loss=25.2874, smooth loss=19.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.86it/s, loss=21.8392]


Epoch 95/2000, Train Loss: 19.4691, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.65it/s, loss=3.2803, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.88it/s, loss=21.8392]


Epoch 96/2000, Train Loss: 18.0736, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.33it/s, loss=5.8048, smooth loss=18.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.57it/s, loss=21.8392]


Epoch 97/2000, Train Loss: 18.4502, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.65it/s, loss=6.1741, smooth loss=18.8]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.81it/s, loss=21.8392]


Epoch 98/2000, Train Loss: 18.7708, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.76it/s, loss=5.8608, smooth loss=18.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.57it/s, loss=21.8392]


Epoch 99/2000, Train Loss: 18.7003, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.25it/s, loss=4.1958, smooth loss=18.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.62it/s, loss=21.8392]


Epoch 100/2000, Train Loss: 18.5216, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 10.91it/s, loss=8.9523, smooth loss=18]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.70it/s, loss=21.8392]


Epoch 101/2000, Train Loss: 17.9720, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.66it/s, loss=12.1356, smooth loss=19]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.51it/s, loss=21.8392]


Epoch 102/2000, Train Loss: 19.0181, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.46it/s, loss=1.1288, smooth loss=17.4]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.43it/s, loss=21.8392]


Epoch 103/2000, Train Loss: 17.3676, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.01it/s, loss=9.0872, smooth loss=18]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.23it/s, loss=21.8392]


Epoch 104/2000, Train Loss: 18.0054, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.62it/s, loss=3.9245, smooth loss=19.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.47it/s, loss=21.8392]


Epoch 105/2000, Train Loss: 19.0591, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.57it/s, loss=10.0387, smooth loss=18.2]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.59it/s, loss=21.8392]


Epoch 106/2000, Train Loss: 18.1611, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:07<00:00, 10.67it/s, loss=12.5621, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.38it/s, loss=21.8392]


Epoch 107/2000, Train Loss: 18.2947, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.17it/s, loss=10.8958, smooth loss=19]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.48it/s, loss=21.8392]


Epoch 108/2000, Train Loss: 19.0410, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:07<00:00, 10.79it/s, loss=2.9379, smooth loss=18]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.37it/s, loss=21.8392]


Epoch 109/2000, Train Loss: 18.0257, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:07<00:00, 10.79it/s, loss=11.8962, smooth loss=18.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.33it/s, loss=21.8392]


Epoch 110/2000, Train Loss: 18.8805, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.22it/s, loss=2.8528, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.43it/s, loss=21.8392]


Epoch 111/2000, Train Loss: 18.6408, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.18it/s, loss=1.1098, smooth loss=18]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.94it/s, loss=21.8392]


Epoch 112/2000, Train Loss: 17.9887, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.54it/s, loss=14.6319, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.72it/s, loss=21.8392]


Epoch 113/2000, Train Loss: 18.6362, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.29it/s, loss=12.4097, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.54it/s, loss=21.8392]


Epoch 114/2000, Train Loss: 18.5734, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.52it/s, loss=3.7081, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.05it/s, loss=21.8392]


Epoch 115/2000, Train Loss: 18.1248, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.57it/s, loss=8.1935, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.72it/s, loss=21.8392]


Epoch 116/2000, Train Loss: 18.6273, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.22it/s, loss=3.2407, smooth loss=18.4]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.58it/s, loss=21.8392]


Epoch 117/2000, Train Loss: 18.3515, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.30it/s, loss=2.0948, smooth loss=17.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.18it/s, loss=21.8392]


Epoch 118/2000, Train Loss: 17.9250, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.90it/s, loss=6.5663, smooth loss=17.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.96it/s, loss=21.8392]


Epoch 119/2000, Train Loss: 17.5328, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.42it/s, loss=34.0577, smooth loss=17.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.64it/s, loss=21.8392]


Epoch 120/2000, Train Loss: 17.4985, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.32it/s, loss=11.2770, smooth loss=18.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.00it/s, loss=21.8392]


Epoch 121/2000, Train Loss: 18.5270, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.67it/s, loss=10.5080, smooth loss=19.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.63it/s, loss=21.8392]


Epoch 122/2000, Train Loss: 19.4897, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.29it/s, loss=1.2976, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.62it/s, loss=21.8392]


Epoch 123/2000, Train Loss: 18.5728, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.43it/s, loss=2.4253, smooth loss=19.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.35it/s, loss=21.8392]


Epoch 124/2000, Train Loss: 19.3417, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.47it/s, loss=11.7596, smooth loss=17.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.91it/s, loss=21.8392]


Epoch 125/2000, Train Loss: 17.1438, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.68it/s, loss=6.2929, smooth loss=17.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.29it/s, loss=21.8392]


Epoch 126/2000, Train Loss: 17.5810, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.53it/s, loss=8.5302, smooth loss=18.5]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.03it/s, loss=21.8392]


Epoch 127/2000, Train Loss: 18.4776, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.68it/s, loss=11.3402, smooth loss=18.8]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.34it/s, loss=21.8392]


Epoch 128/2000, Train Loss: 18.7688, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.41it/s, loss=49.9336, smooth loss=19.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.65it/s, loss=21.8392]


Epoch 129/2000, Train Loss: 19.3100, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 10.92it/s, loss=2.5661, smooth loss=17.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.66it/s, loss=21.8392]


Epoch 130/2000, Train Loss: 17.6528, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.02it/s, loss=8.7111, smooth loss=19.2]
Validation: 100%|██████████| 76/76 [00:04<00:00, 17.81it/s, loss=21.8392]


Epoch 131/2000, Train Loss: 19.2072, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.57it/s, loss=9.5819, smooth loss=18.4]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.76it/s, loss=21.8392]


Epoch 132/2000, Train Loss: 18.4171, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.28it/s, loss=33.8673, smooth loss=18.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.04it/s, loss=21.8392]


Epoch 133/2000, Train Loss: 18.9068, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.59it/s, loss=2.5367, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.41it/s, loss=21.8392]


Epoch 134/2000, Train Loss: 18.1350, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.18it/s, loss=4.6239, smooth loss=18.1]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.27it/s, loss=21.8392]


Epoch 135/2000, Train Loss: 18.1069, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.00it/s, loss=9.6144, smooth loss=18.4]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.54it/s, loss=21.8392]


Epoch 136/2000, Train Loss: 18.4094, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.65it/s, loss=7.2970, smooth loss=18.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.32it/s, loss=21.8392]


Epoch 137/2000, Train Loss: 18.7280, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.53it/s, loss=9.0461, smooth loss=18]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.91it/s, loss=21.8392]


Epoch 138/2000, Train Loss: 18.0189, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.02it/s, loss=16.7542, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.45it/s, loss=21.8392]


Epoch 139/2000, Train Loss: 18.2535, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.03it/s, loss=11.9609, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.54it/s, loss=21.8392]


Epoch 140/2000, Train Loss: 18.5634, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.55it/s, loss=0.4239, smooth loss=17.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.92it/s, loss=21.8392]


Epoch 141/2000, Train Loss: 17.6531, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.20it/s, loss=4.7721, smooth loss=18.4]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.41it/s, loss=21.8392]


Epoch 142/2000, Train Loss: 18.3844, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 10.95it/s, loss=12.7290, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.63it/s, loss=21.8392]


Epoch 143/2000, Train Loss: 18.2901, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.42it/s, loss=11.4677, smooth loss=18.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 15.57it/s, loss=21.8392]


Epoch 144/2000, Train Loss: 18.8604, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 10.88it/s, loss=2.0507, smooth loss=18.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.08it/s, loss=21.8392]


Epoch 145/2000, Train Loss: 18.6573, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.57it/s, loss=3.7702, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.33it/s, loss=21.8392]


Epoch 146/2000, Train Loss: 18.3458, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.28it/s, loss=1.0848, smooth loss=18.7]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.21it/s, loss=21.8392]


Epoch 147/2000, Train Loss: 18.6739, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.46it/s, loss=9.3160, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.36it/s, loss=21.8392]


Epoch 148/2000, Train Loss: 18.3014, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 10.92it/s, loss=13.6567, smooth loss=18.3]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.71it/s, loss=21.8392]


Epoch 149/2000, Train Loss: 18.3431, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.42it/s, loss=6.9316, smooth loss=18.6]
Validation: 100%|██████████| 76/76 [00:04<00:00, 16.99it/s, loss=21.8392]


Epoch 150/2000, Train Loss: 18.6160, Val Loss: 13.3715, LR: 0.000000


Training: 100%|██████████| 76/76 [00:06<00:00, 11.35it/s, loss=1.6645, smooth loss=17.9]
Validation: 100%|██████████| 76/76 [00:04<00:00, 18.11it/s, loss=21.8392]

Epoch 151/2000, Train Loss: 17.9398, Val Loss: 13.3715, LR: 0.000000
Early stopping triggered after 151 epochs
Training completed!





In [23]:
# Plot training and validation losses
plt.figure(figsize=(10, 6))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('RMSD Loss')
plt.title('Training and Validation Losses')
plt.legend()
plt.grid(True)
plt.savefig('training_loss.png')
plt.close()

In [24]:
# Generate multiple conformations for each RNA sequence
def generate_multiple_conformations(model, data, num_conformations=5):
    """
    Generate multiple structural conformations for an RNA sequence.
    
    Args:
        model: The trained GNN model
        data: Graph data object containing the RNA sequence
        num_conformations: Number of conformations to generate (default: 5)
        
    Returns:
        List of numpy arrays, each array has shape (n_nucleotides, 3) for x,y,z coordinates
    """
    model.eval()
    conformations = []
    
    # Set random seed for reproducibility
    torch.manual_seed(42)
    
    with torch.no_grad():
        # Generate first conformation (deterministic)
        base_pred = model(data)
        base_np = base_pred.cpu().numpy()
        
        # Check if base prediction contains NaN values
        if np.isnan(base_np).any():
            print("Warning: Base prediction contains NaN values. Replacing with zeros.")
            base_np = np.nan_to_num(base_np, nan=0.0)
        
        # Save the base prediction
        conformations.append(base_np)
        
        # Generate additional conformations with controlled variations
        for i in range(1, num_conformations):
            # Use different seeds for different conformations
            torch.manual_seed(42 + i * 100)  # Larger seed increment for more diversity
            
            # Create a copy of the base prediction with a small, controlled variation
            variation = base_np.copy()
            
            # Add random noise with small magnitude (1-5% of the coordinate values)
            # Calculate standard deviation of base coordinates to scale noise appropriately
            if not np.all(base_np == 0):  # Check if base_np is not all zeros
                coord_std = max(np.std(base_np), 0.5)  # Use at least 0.5 to avoid too small noise
                noise_scale = coord_std * 0.05 * (i + 1)  # Increasing noise for each conformation
            else:
                # If base prediction is all zeros (which shouldn't happen normally)
                noise_scale = 0.5 * (i + 1)
            
            # Generate noise and ensure it's not NaN
            noise = np.random.normal(0, noise_scale, size=variation.shape)
            
            # Apply noise to create a new conformation
            variation += noise
            
            # Ensure no NaN values
            variation = np.nan_to_num(variation, nan=0.0)
            
            conformations.append(variation)
    
    # Double-check that all conformations are valid and contain no NaNs
    for i, conf in enumerate(conformations):
        if np.isnan(conf).any():
            print(f"Warning: Conformation {i+1} contains NaN values after processing. Replacing with zeros.")
            conformations[i] = np.nan_to_num(conf, nan=0.0)
    
    return conformations

# Function to make multiple predictions for test data
def predict_multiple_conformations(model, test_loader, device, num_conformations=5):
    predictions = {}
    
    for data in test_loader:
        data = data.to(device)
        conformations = generate_multiple_conformations(model, data, num_conformations)
        
        # Store predictions - ensure target_id is a hashable type (string)
        # The target_id could be stored as a list or other non-hashable type
        if hasattr(data, 'target_id'):
            # Convert to string if it's not already
            if isinstance(data.target_id, list) and len(data.target_id) > 0:
                target_id = str(data.target_id[0])  # Take the first element if it's a list
            else:
                target_id = str(data.target_id)  # Convert to string to ensure hashability
        else:
            # Generate a unique ID if none exists
            target_id = f"unknown_target_{len(predictions)}"
            
        print(f"Processing target: {target_id}")
        predictions[target_id] = conformations
        
        # If we have ground truth, report metrics for the first conformation
        if hasattr(data, 'y') and data.y is not None and len(conformations) > 0:
            first_conf = torch.tensor(conformations[0], device=device)
            
            if hasattr(data, 'mask') and data.mask is not None:
                loss = rmsd_loss(first_conf, data.y, data.mask).item()
            else:
                loss = rmsd_loss(first_conf, data.y).item()
                
            print(f"Prediction for {target_id}, RMSD of first conformation: {loss:.4f}")
    
    return predictions

# Example of how to use the prediction function on test data
def process_test_data(test_sequences_path):
    # Load test sequences
    test_sequences = pd.read_csv(test_sequences_path)
    
    # Create test dataset (without labels)
    test_dataset = create_dataset(test_sequences)
    
    # Create test loader
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    
    # Make predictions
    predictions = predict_multiple_conformations(model, test_loader, device)
    
    # Format predictions for submission
    formatted_predictions = []
    
    for target_id, conformations in predictions.items():
        for i, conformation in enumerate(conformations):
            for j, coords in enumerate(conformation):
                resid = j + 1  # 1-based indexing
                row = {
                    'ID': f"{target_id}_{resid}",
                    f'x_{i+1}': coords[0],
                    f'y_{i+1}': coords[1],
                    f'z_{i+1}': coords[2]
                }
                formatted_predictions.append(row)
    
    # Create submission dataframe
    submission_df = pd.DataFrame(formatted_predictions)
    return submission_df

In [25]:
test_predictions = process_test_data("/kaggle/input/stanford-rna-3d-folding/test_sequences.csv")
sub = pd.read_csv("/kaggle/input/stanford-rna-3d-folding/sample_submission.csv")
DF_ROWS = []

for i, row in sub.iterrows():
    snap = test_predictions[test_predictions['ID'] == row['ID']]
    x1, y1, z1, x2, y2, z2, x3, y3, z3, x4, y4, z4, x5, y5, z5 = snap['x_1'], snap['y_1'], snap['z_1'], snap['x_2'], snap['y_2'], snap['z_2'], snap['x_3'], snap['y_3'], snap['z_3'], snap['x_4'], snap['y_4'], snap['z_4'], snap['x_5'], snap['y_5'], snap['z_5']
    x1, y1, z1, x2, y2, z2, x3, y3, z3, x4, y4, z4, x5, y5, z5 = x1.values[0], y1.values[0], z1.values[0], x2.values[1], y2.values[1], z2.values[1], x3.values[2], y3.values[2], z3.values[2], x4.values[3], y4.values[3], z4.values[3], x5.values[4], y5.values[4], z5.values[4]
    _row = [x1, y1, z1, x2, y2, z2, x3, y3, z3, x4, y4, z4, x5, y5, z5]
    DF_ROWS.append(_row)
sub[['x_1', 'y_1', 'z_1', 'x_2', 'y_2', 'z_2', 'x_3', 'y_3', 'z_3', 'x_4', 'y_4', 'z_4', 'x_5', 'y_5', 'z_5']] = DF_ROWS
sub.head()
sub.to_csv("submission.csv", index=False)

100%|██████████| 12/12 [00:00<00:00, 126.35it/s]


Dataset creation: 0 sequences skipped due to non-standard nucleotides
Dataset creation: 0 sequences have >50% NaN coordinates
Processing target: R1107
Processing target: R1108
Processing target: R1116
Processing target: R1117v2
Processing target: R1126
Processing target: R1128
Processing target: R1136
Processing target: R1138
Processing target: R1149
Processing target: R1156
Processing target: R1189
Processing target: R1190
