## Step 1: Imports

In [1]:
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
import pickle
from tqdm import tqdm

In [2]:
#set seed for everything
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

## Step 2: Defining Config

In [3]:
config = {
    "seed"                         : 0,
    "cutoff_date"                  : "2020-01-01",
    "test_cutoff_date"             : "2022-05-01",
    "max_len"                      : 384,
    "batch_size"                   : 1,
    "learning_rate"                : 1e-4,
    "weight_decay"                 : 0.0,
    "mixed_precision"              : "bf16",
    "model_config_path"            : "../working/configs/pairwise.yaml",  # Adjust path as needed
    "epochs"                       : 10,
    "cos_epoch"                    : 5,
    "loss_power_scale"             : 1.0,
    "max_cycles"                   : 1,
    "grad_clip"                    : 0.1,
    "gradient_accumulation_steps"  : 1,
    "d_clamp"                      : 30,
    "max_len_filter"               : 9999999,
    "min_len_filter"               : 10, 
    "structural_violation_epoch"   : 50,
    "balance_weight"               : False,
}

## Step 3 : Get data and do some data processing¶


In [4]:
valid_sequences = pd.read_csv("/kaggle/input/stanford-rna-3d-folding/train_sequences.csv")
valid_labels    = pd.read_csv("/kaggle/input/stanford-rna-3d-folding/train_labels.csv")

valid_sequences.drop(index=[610, 639, 687, 782], inplace=True)
valid_labels.drop(index=[610, 639, 687, 782], inplace=True)

valid_sequences = valid_sequences.reset_index()
valid_labels = valid_labels.reset_index()

In [5]:
valid_labels["pdb_id"] = valid_labels["ID"].apply(lambda x: x.split("_")[0]+'_'+x.split("_")[1])
valid_labels["pdb_id"] 

0         1SCL_A
1         1SCL_A
2         1SCL_A
3         1SCL_A
4         1SCL_A
           ...  
137086    8Z1F_T
137087    8Z1F_T
137088    8Z1F_T
137089    8Z1F_T
137090    8Z1F_T
Name: pdb_id, Length: 137091, dtype: object

### Getting the trainig data ready

In [6]:
import pandas as pd
import os
from tqdm import tqdm

Test = False
# Assuming your CSV files are in a folder
folder_path = '/kaggle/input/rna-dataset-2/extracted'  # Replace with your actual folder path
train_sequences = pd.DataFrame(columns=['target_id', 'sequence', 'temporal_cutoff', 'description', 'all_sequences'])
train_labels    = pd.DataFrame(columns=['ID', 'resname','resid','x_1','y_1','z_1','pdb_id'])

# Process each CSV file in the folder
i = 0
for file in tqdm(os.listdir(folder_path)):
    i += 1
    if Test:
        if i == 100:
            break
    if file.endswith('.csv'):
        file_path = os.path.join(folder_path, file)
        df = pd.read_csv(file_path)
        
        # Extract target_id from the first row
        target_id = df['target_id'].iloc[0]
        
        # Create sequence by concatenating all resname values
        sequence = ''.join(df['resname'])
        
        # Add to result dataframe
        new_row = pd.DataFrame({
            'target_id': [target_id],
            'sequence': [sequence],
            'temporal_cutoff': [''],
            'description': [''],
            'all_sequences': ['']
        })
        
        train_sequences = pd.concat([train_sequences, new_row], ignore_index=True)

        # Create labels dataframe entries
        labels_rows = []
        for _, row in df.iterrows():
            # Create ID in format target_id_resid
            id_value = f"{target_id}_{row['resid']}"
            
            labels_rows.append({
                'ID': id_value,
                'resname': row['resname'],
                'resid': row['resid'],
                'x_1': row['x_1'],
                'y_1': row['y_1'],
                'z_1': row['z_1'],
                'pdb_id': target_id
            })
        
        # Add to labels dataframe
        train_labels = pd.concat([train_labels, pd.DataFrame(labels_rows)], ignore_index=True)

# Save the result to a new CSV file
train_sequences.to_csv('converted_data.csv', index=False)
train_labels.to_csv('converted_labels.csv', index=False)

  train_labels = pd.concat([train_labels, pd.DataFrame(labels_rows)], ignore_index=True)
100%|██████████| 3502/3502 [01:13<00:00, 47.42it/s]


In [7]:
all_xyz=[]

for pdb_id in tqdm(train_sequences['target_id']):
    df = train_labels[train_labels["pdb_id"]==pdb_id]
    xyz=df[['x_1','y_1','z_1']].to_numpy().astype('float32')
    xyz[xyz<-1e17]=float('Nan');
    all_xyz.append(xyz)
df

  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xy

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,pdb_id
261583,1ykq_B_1,G,1,-20.224393,-16.491967,10.122993,1ykq_B
261584,1ykq_B_2,G,2,-19.009405,-18.617505,3.605266,1ykq_B
261585,1ykq_B_3,G,3,-18.74984,-16.821151,-1.208829,1ykq_B
261586,1ykq_B_4,C,4,-19.168544,-14.029909,-5.162431,1ykq_B
261587,1ykq_B_5,G,5,-21.532162,-7.10573,-5.407534,1ykq_B
261588,1ykq_B_6,A,6,-20.545637,-2.526273,-3.256684,1ykq_B
261589,1ykq_B_7,G,7,-17.60603,1.6246,0.14767,1ykq_B
261590,1ykq_B_8,G,8,-17.396982,8.433533,-4.715955,1ykq_B
261591,1ykq_B_9,C,9,-13.165123,12.74604,-5.572929,1ykq_B
261592,1ykq_B_10,C,10,-8.009336,13.896392,-3.60407,1ykq_B


In [8]:
filter_nan = []
max_len    = 0

for xyz in all_xyz:
    if len(xyz) > max_len:
        max_len = len(xyz)
    filter_nan.append((np.isnan(xyz).mean() <= 0.5) & \
                      (len(xyz)<config['max_len_filter']) & \
                      (len(xyz)>config['min_len_filter']))
print(f"Longest sequence in train: {max_len}")

filter_nan      = np.array(filter_nan)
non_nan_indices = np.arange(len(filter_nan))[filter_nan]

train_sequences = train_sequences.loc[non_nan_indices].reset_index(drop=True)
all_xyz         = [all_xyz[i] for i in non_nan_indices]

Longest sequence in train: 960


In [9]:
all_xyz_valid=[]

for pdb_id in tqdm(valid_sequences['target_id']):
    df = valid_labels[valid_labels["pdb_id"]==pdb_id]
    xyz=df[['x_1','y_1','z_1']].to_numpy().astype('float32')
    xyz[xyz<-1e17]=float('Nan');
    all_xyz_valid.append(xyz)

  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xy

In [10]:
filter_nan_valid = []
max_len_valid    = 0

for xyz in all_xyz_valid:
    if len(xyz) > max_len_valid:
        max_len_valid = len(xyz)
    filter_nan_valid.append((np.isnan(xyz).mean() <= 0.5) & \
                      (len(xyz)<config['max_len_filter']) & \
                      (len(xyz)>config['min_len_filter']))
print(f"Longest sequence in train: {max_len_valid}")

filter_nan_valid      = np.array(filter_nan_valid)
non_nan_indices_valid = np.arange(len(filter_nan_valid))[filter_nan_valid]

valid_sequences = valid_sequences.loc[non_nan_indices_valid].reset_index(drop=True)
all_xyz_valid   = [all_xyz_valid[i] for i in non_nan_indices_valid]

Longest sequence in train: 4298


In [11]:
#pack data into a dictionary

data={
      "sequence":train_sequences['sequence'].to_list(),
      "temporal_cutoff": train_sequences['temporal_cutoff'].to_list(),
      "description": train_sequences['description'].to_list(),
      "all_sequences": train_sequences['all_sequences'].to_list(),
      "xyz": all_xyz
}

In [12]:
#pack data into a dictionary

valid_data={
      "sequence":valid_sequences['sequence'].to_list(),
      "temporal_cutoff": valid_sequences['temporal_cutoff'].to_list(),
      "description": valid_sequences['description'].to_list(),
      "all_sequences": valid_sequences['all_sequences'].to_list(),
      "xyz": all_xyz_valid
}

## Step 4: Split train data into train/val/test¶
We will simply do a temporal split, because that's how testing is done in structural biology in general (in actual blind tests)

In [13]:
# # Split data into train and test
# all_index        = np.arange(len(data['sequence']))
# cutoff_date      = pd.Timestamp(config['cutoff_date'])
# test_cutoff_date = pd.Timestamp(config['test_cutoff_date'])
# train_index      = [i for i, d in enumerate(data['temporal_cutoff']) if pd.Timestamp(d) <= cutoff_date]
# test_index       = [i for i, d in enumerate(data['temporal_cutoff']) if pd.Timestamp(d) > cutoff_date and pd.Timestamp(d) <= test_cutoff_date]

In [14]:
print(f"Train size: {len(train_sequences)}")
print(f"Test size: {len(valid_sequences)}")

Train size: 3412
Test size: 764


## Step 5: Get pytorch dataset¶

In [15]:
from torch.utils.data import Dataset, DataLoader
from ast import literal_eval

def get_ct(bp,s):
    ct_matrix=np.zeros((len(s),len(s)))
    for b in bp:
        ct_matrix[b[0]-1,b[1]-1]=1
    return ct_matrix

class RNA3D_Dataset(Dataset):
    def __init__(self,indices,data):
        self.indices  = indices
        self.data     = data
        self.tokens   = {nt:i for i,nt in enumerate('ACGU')}

    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):

        idx      = self.indices[idx]
        sequence = [self.tokens[nt] for nt in (self.data['sequence'][idx])]
        sequence = np.array(sequence)
        sequence = torch.tensor(sequence)

        #get C1' xyz
        xyz      = self.data['xyz'][idx]
        xyz      = torch.tensor(np.array(xyz))


        if len(sequence)>config['max_len']:
            crop_start  = np.random.randint(len(sequence)-config['max_len'])
            crop_end    = crop_start+config['max_len']
            sequence    = sequence[crop_start:crop_end]
            xyz         = xyz[crop_start:crop_end]
            
        return {'sequence' : sequence, 'xyz' : xyz}

In [16]:
train_dataset = RNA3D_Dataset(train_sequences.index,data)
val_dataset   = RNA3D_Dataset(valid_sequences.index,valid_data)

In [17]:
# for i in range(len(val_dataset)):
#     try:
#         length = len(val_dataset[i]['sequence'])
#     except Exception as e:
#         print(i,e)

In [18]:
def filter_mismatched_samples(dataset):
    valid_indices = []
    
    for i in range(len(dataset)):
        sample = dataset[i]
        # Check if the first dimension of sequence matches the first dimension of xyz
        if sample['sequence'].shape[0] == sample['xyz'].shape[0]:
            valid_indices.append(i)
        else:
            print(f"Removing sample {i}: sequence shape {sample['sequence'].shape}, xyz shape {sample['xyz'].shape}")
    
    # Create a filtered dataset using a subset of the original dataset
    from torch.utils.data import Subset
    filtered_dataset = Subset(dataset, valid_indices)
    
    print(f"Original dataset size: {len(dataset)}")
    print(f"Filtered dataset size: {len(filtered_dataset)}")
    
    return filtered_dataset

val_dataset = filter_mismatched_samples(val_dataset)


Removing sample 26: sequence shape torch.Size([13]), xyz shape torch.Size([12, 3])
Removing sample 27: sequence shape torch.Size([29]), xyz shape torch.Size([28, 3])
Removing sample 30: sequence shape torch.Size([19]), xyz shape torch.Size([18, 3])
Removing sample 34: sequence shape torch.Size([44]), xyz shape torch.Size([43, 3])
Original dataset size: 764
Filtered dataset size: 760


In [19]:
train_loader = DataLoader(train_dataset,batch_size=1,shuffle=True)
val_loader   = DataLoader(val_dataset,batch_size=1,shuffle=False)

## Step 6: Get RibonanzaNet¶
We will add a linear layer to predict xyz of C1' atoms

In [20]:
! pip install einops



In [21]:
import sys

sys.path.append("/kaggle/input/ribonanzanet2d-final")


from Network import *
import yaml

class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        self.entries=entries

    def print(self):
        print(self.entries)

def load_config_from_yaml(file_path):
    with open(file_path, 'r') as file:
        config = yaml.safe_load(file)
    return Config(**config)

class finetuned_RibonanzaNet(RibonanzaNet):
    def __init__(self, config, pretrained=False):
        config.dropout=0.1
        super(finetuned_RibonanzaNet, self).__init__(config)
        #self.ct_predictor=nn.Linear(64,1)
        if pretrained:
            self.load_state_dict(torch.load("/kaggle/input/ribonanzanet-weights/RibonanzaNet.pt",map_location='cpu'))
        
        self.dropout=nn.Dropout(0.0)
        self.xyz_predictor=nn.Linear(256,3)

    def forward(self,src):
        
        #with torch.no_grad():
        sequence_features, pairwise_features=self.get_embeddings(src, torch.ones_like(src).long().to(src.device))
        xyz=self.xyz_predictor(sequence_features)

        return xyz

In [22]:
model = finetuned_RibonanzaNet(load_config_from_yaml("/kaggle/input/ribonanzanet2d-final/configs/pairwise.yaml"),pretrained=True).cuda()

constructing 9 ConvTransformerEncoderLayers


  self.load_state_dict(torch.load("/kaggle/input/ribonanzanet-weights/RibonanzaNet.pt",map_location='cpu'))


## Step 7: Training loop¶
we will use dRMSD loss on the predicted xyz. the loss function is invariant to translations, rotations, and reflections. because dRMSD is invariant to reflections, it cannot distinguish chiral structures, so there may be better loss functions

In [23]:
def calculate_distance_matrix(X,Y,epsilon=1e-4):
    return (torch.square(X[:,None]-Y[None,:])+epsilon).sum(-1).sqrt()


def dRMSD(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=None):
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)

    mask=~torch.isnan(gt_dm)
    mask[torch.eye(mask.shape[0]).bool()]=False

    if d_clamp is not None:
        rmsd=(torch.square(pred_dm[mask]-gt_dm[mask])+epsilon).clip(0,d_clamp**2)
    else:
        rmsd=torch.square(pred_dm[mask]-gt_dm[mask])+epsilon

    return rmsd.sqrt().mean()/Z

def local_dRMSD(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=30):
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)



    mask=(~torch.isnan(gt_dm))*(gt_dm<d_clamp)
    mask[torch.eye(mask.shape[0]).bool()]=False



    rmsd=torch.square(pred_dm[mask]-gt_dm[mask])+epsilon
    # rmsd=(torch.square(pred_dm[mask]-gt_dm[mask])+epsilon).sqrt()/Z
    #rmsd=torch.abs(pred_dm[mask]-gt_dm[mask])/Z
    return rmsd.sqrt().mean()/Z

def dRMAE(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=None):
    flag = False
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)


    if pred_dm.shape[0] != gt_dm.shape[0]:
        print(f"Got it")
        print(pred_dm.shape)
        print(gt_dm.shape)
        flag = True
        
    mask=~torch.isnan(gt_dm)
    
    mask[torch.eye(mask.shape[0]).bool()]=False

    rmsd=torch.abs(pred_dm[mask]-gt_dm[mask])

    return rmsd.mean()/Z, flag

import torch

def align_svd_mae(input, target, Z=10):
    """
    Aligns the input (Nx3) to target (Nx3) using SVD-based Procrustes alignment
    and computes RMSD loss.
    
    Args:
        input (torch.Tensor): Nx3 tensor representing the input points.
        target (torch.Tensor): Nx3 tensor representing the target points.
    
    Returns:
        aligned_input (torch.Tensor): Nx3 aligned input.
        rmsd_loss (torch.Tensor): RMSD loss.
    """
    assert input.shape == target.shape, "Input and target must have the same shape"

    #mask 
    mask=~torch.isnan(target.sum(-1))

    input=input[mask]
    target=target[mask]
    
    # Compute centroids
    centroid_input = input.mean(dim=0, keepdim=True)
    centroid_target = target.mean(dim=0, keepdim=True)

    # Center the points
    input_centered = input - centroid_input.detach()
    target_centered = target - centroid_target

    # Compute covariance matrix
    cov_matrix = input_centered.T @ target_centered

    # SVD to find optimal rotation
    U, S, Vt = torch.svd(cov_matrix)

    # Compute rotation matrix
    R = Vt @ U.T

    # Ensure a proper rotation (det(R) = 1, no reflection)
    if torch.det(R) < 0:
        Vt[-1, :] *= -1
        R = Vt @ U.T

    # Rotate input
    aligned_input = (input_centered @ R.T.detach()) + centroid_target.detach()
    
    return torch.abs(aligned_input-target).mean()/Z

In [24]:
from tqdm import tqdm
from torch.amp import GradScaler

if Test:
    epochs=1
    cos_epoch=1
else:
    epochs=50
    cos_epoch=35


best_loss=np.inf
optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.0, lr=0.0001) #no weight decay following AF

batch_size=1

#for cycle in range(2):

criterion=torch.nn.BCEWithLogitsLoss(reduction='none')

scaler = GradScaler()

schedule=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(epochs-cos_epoch)*len(train_loader)//batch_size)

best_val_loss=99999999999
for epoch in range(epochs):
    model.train()
    tbar=tqdm(train_loader)
    total_loss=0
    oom=0
    for idx, batch in enumerate(tbar):
        #try:
        sequence=batch['sequence'].cuda()
        gt_xyz=batch['xyz'].cuda().squeeze()

        #with torch.autocast(device_type='cuda', dtype=torch.float16):
        pred_xyz=model(sequence).squeeze()
        
        loss1, flag=dRMAE(pred_xyz,pred_xyz,gt_xyz,gt_xyz) 
        loss = loss1 + align_svd_mae(pred_xyz, gt_xyz)
             #local_dRMSD(pred_xyz,pred_xyz,gt_xyz,gt_xyz)

        if flag:
            print(f"Error Sequence : {sequence, gt_xyz}")
            
        if loss!=loss:
            stop

        
        (loss/batch_size).backward()

        if (idx+1)%batch_size==0 or idx+1 == len(tbar):

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            optimizer.zero_grad()
            
            if (epoch+1)>cos_epoch:
                schedule.step()
      
        total_loss+=loss.item()
        
        tbar.set_description(f"Epoch {epoch + 1} Loss: {total_loss/(idx+1)} OOMs: {oom}")

    tbar=tqdm(val_loader)
    model.eval()
    val_preds=[]
    val_loss=0
    for idx, batch in enumerate(tbar):
        sequence=batch['sequence'].cuda()
        gt_xyz=batch['xyz'].cuda().squeeze()

        with torch.no_grad():
            pred_xyz=model(sequence).squeeze()
            loss, flag=dRMAE(pred_xyz,pred_xyz,gt_xyz,gt_xyz)
            
        val_loss+=loss.item()
        val_preds.append([gt_xyz.cpu().numpy(),pred_xyz.cpu().numpy()])
    val_loss=val_loss/len(tbar)
    print(f"val loss: {val_loss}")
    
    
    
    if val_loss<best_val_loss:
        best_val_loss=val_loss
        best_preds=val_preds
        torch.save(model.state_dict(),'RibonanzaNet-3D.pt')

    # 1.053595052265986 train loss after epoch 0
torch.save(model.state_dict(),'RibonanzaNet-3D-final.pt')

  return fn(*args, **kwargs)
Epoch 1 Loss: 2.63013479079514 OOMs: 0: 100%|██████████| 3412/3412 [10:07<00:00,  5.61it/s]
100%|██████████| 760/760 [00:45<00:00, 16.83it/s]


val loss: 1.0122510364377184


Epoch 2 Loss: 1.6991498854976466 OOMs: 0: 100%|██████████| 3412/3412 [10:05<00:00,  5.64it/s]
100%|██████████| 760/760 [00:45<00:00, 16.80it/s]


val loss: 0.825936627623282


Epoch 3 Loss: 1.448760587394517 OOMs: 0: 100%|██████████| 3412/3412 [10:07<00:00,  5.61it/s]
100%|██████████| 760/760 [00:45<00:00, 16.77it/s]


val loss: 0.7578543737236606


Epoch 4 Loss: 1.2737084365551385 OOMs: 0: 100%|██████████| 3412/3412 [10:06<00:00,  5.63it/s]
100%|██████████| 760/760 [00:45<00:00, 16.82it/s]


val loss: 0.6605237421040472


Epoch 5 Loss: 1.1466113532580007 OOMs: 0: 100%|██████████| 3412/3412 [10:09<00:00,  5.60it/s]
100%|██████████| 760/760 [00:45<00:00, 16.80it/s]


val loss: 0.6187622575089335


Epoch 6 Loss: 1.0557277673005987 OOMs: 0: 100%|██████████| 3412/3412 [10:05<00:00,  5.63it/s]
100%|██████████| 760/760 [00:45<00:00, 16.74it/s]


val loss: 0.603022581897676


Epoch 7 Loss: 0.9802063234098002 OOMs: 0: 100%|██████████| 3412/3412 [10:07<00:00,  5.62it/s]
100%|██████████| 760/760 [00:45<00:00, 16.79it/s]


val loss: 0.604051398809411


Epoch 8 Loss: 0.9211117105486455 OOMs: 0: 100%|██████████| 3412/3412 [10:05<00:00,  5.63it/s]
100%|██████████| 760/760 [00:44<00:00, 16.92it/s]


val loss: 0.5748598375014569


Epoch 9 Loss: 0.874432262841171 OOMs: 0: 100%|██████████| 3412/3412 [09:52<00:00,  5.76it/s]
100%|██████████| 760/760 [00:44<00:00, 17.04it/s]


val loss: 0.569460204665206


Epoch 10 Loss: 0.8308126990708489 OOMs: 0: 100%|██████████| 3412/3412 [09:55<00:00,  5.73it/s]
100%|██████████| 760/760 [00:45<00:00, 16.86it/s]


val loss: 0.5445793320277804


Epoch 11 Loss: 0.798288127951158 OOMs: 0: 100%|██████████| 3412/3412 [10:01<00:00,  5.68it/s]
100%|██████████| 760/760 [00:44<00:00, 17.06it/s]


val loss: 0.5233549134217594


Epoch 12 Loss: 0.765400683837053 OOMs: 0: 100%|██████████| 3412/3412 [09:56<00:00,  5.72it/s]
100%|██████████| 760/760 [00:45<00:00, 16.83it/s]


val loss: 0.5261005681890406


Epoch 13 Loss: 0.7343827488309194 OOMs: 0: 100%|██████████| 3412/3412 [10:02<00:00,  5.66it/s]
100%|██████████| 760/760 [00:45<00:00, 16.87it/s]


val loss: 0.5093606946989894


Epoch 14 Loss: 0.7113317568619394 OOMs: 0: 100%|██████████| 3412/3412 [10:02<00:00,  5.66it/s]
100%|██████████| 760/760 [00:45<00:00, 16.73it/s]


val loss: 0.5247547588017034


Epoch 15 Loss: 0.6863622933075898 OOMs: 0: 100%|██████████| 3412/3412 [10:00<00:00,  5.68it/s]
100%|██████████| 760/760 [00:44<00:00, 16.94it/s]


val loss: 0.5137224103275099


Epoch 16 Loss: 0.6656217974548393 OOMs: 0: 100%|██████████| 3412/3412 [10:00<00:00,  5.68it/s]
100%|██████████| 760/760 [00:45<00:00, 16.84it/s]


val loss: 0.4957782459445298


Epoch 17 Loss: 0.6508533743484184 OOMs: 0: 100%|██████████| 3412/3412 [10:01<00:00,  5.67it/s]
100%|██████████| 760/760 [00:45<00:00, 16.61it/s]


val loss: 0.5024764935044865


Epoch 18 Loss: 0.6288595841815401 OOMs: 0: 100%|██████████| 3412/3412 [09:58<00:00,  5.70it/s]
100%|██████████| 760/760 [00:44<00:00, 16.95it/s]


val loss: 0.5108312939097615


Epoch 19 Loss: 0.6152396792412712 OOMs: 0: 100%|██████████| 3412/3412 [09:58<00:00,  5.70it/s]
100%|██████████| 760/760 [00:44<00:00, 16.94it/s]


val loss: 0.4982888803474213


Epoch 20 Loss: 0.598826937200578 OOMs: 0: 100%|██████████| 3412/3412 [10:03<00:00,  5.66it/s]
100%|██████████| 760/760 [00:45<00:00, 16.76it/s]


val loss: 0.5126010220693914


Epoch 21 Loss: 0.5872691464495756 OOMs: 0: 100%|██████████| 3412/3412 [10:08<00:00,  5.61it/s]
100%|██████████| 760/760 [00:45<00:00, 16.76it/s]


val loss: 0.48945972292536966


Epoch 22 Loss: 0.5727332368021414 OOMs: 0: 100%|██████████| 3412/3412 [10:09<00:00,  5.60it/s]
100%|██████████| 760/760 [00:45<00:00, 16.80it/s]


val loss: 0.48521986914994686


Epoch 23 Loss: 0.5575436443780469 OOMs: 0: 100%|██████████| 3412/3412 [10:03<00:00,  5.65it/s]
100%|██████████| 760/760 [00:44<00:00, 16.92it/s]


val loss: 0.4862393673509359


Epoch 24 Loss: 0.550487911221675 OOMs: 0: 100%|██████████| 3412/3412 [10:02<00:00,  5.67it/s]
100%|██████████| 760/760 [00:45<00:00, 16.77it/s]


val loss: 0.4562255154608896


Epoch 25 Loss: 0.5383335867807146 OOMs: 0: 100%|██████████| 3412/3412 [10:04<00:00,  5.64it/s]
100%|██████████| 760/760 [00:45<00:00, 16.80it/s]


val loss: 0.47107328176498414


Epoch 26 Loss: 0.5269554027302709 OOMs: 0: 100%|██████████| 3412/3412 [10:03<00:00,  5.66it/s]
100%|██████████| 760/760 [00:45<00:00, 16.89it/s]


val loss: 0.45454770509937875


Epoch 27 Loss: 0.5198747554896584 OOMs: 0: 100%|██████████| 3412/3412 [10:12<00:00,  5.57it/s]
100%|██████████| 760/760 [00:45<00:00, 16.65it/s]


val loss: 0.4510059370659292


Epoch 28 Loss: 0.5089436021086444 OOMs: 0: 100%|██████████| 3412/3412 [10:13<00:00,  5.56it/s]
100%|██████████| 760/760 [00:45<00:00, 16.77it/s]


val loss: 0.45773747229066336


Epoch 29 Loss: 0.5010320884344831 OOMs: 0: 100%|██████████| 3412/3412 [10:09<00:00,  5.60it/s]
100%|██████████| 760/760 [00:45<00:00, 16.68it/s]


val loss: 0.4582272725866029


Epoch 30 Loss: 0.49542009605996634 OOMs: 0: 100%|██████████| 3412/3412 [10:06<00:00,  5.63it/s]
100%|██████████| 760/760 [00:45<00:00, 16.76it/s]


val loss: 0.44534884708394346


Epoch 31 Loss: 0.48246341187327857 OOMs: 0: 100%|██████████| 3412/3412 [10:20<00:00,  5.50it/s]
100%|██████████| 760/760 [00:46<00:00, 16.31it/s]


val loss: 0.4643015937340495


Epoch 32 Loss: 0.4787195046505468 OOMs: 0: 100%|██████████| 3412/3412 [11:14<00:00,  5.06it/s]
100%|██████████| 760/760 [00:49<00:00, 15.46it/s]


val loss: 0.43284502307835376


Epoch 33 Loss: 0.47286443001779555 OOMs: 0: 100%|██████████| 3412/3412 [10:38<00:00,  5.34it/s]
100%|██████████| 760/760 [00:45<00:00, 16.75it/s]


val loss: 0.43618874203805863


Epoch 34 Loss: 0.4672221792195809 OOMs: 0: 100%|██████████| 3412/3412 [10:11<00:00,  5.58it/s]
100%|██████████| 760/760 [00:45<00:00, 16.78it/s]


val loss: 0.4438545821381635


Epoch 35 Loss: 0.4594363097083981 OOMs: 0: 100%|██████████| 3412/3412 [10:06<00:00,  5.62it/s]
100%|██████████| 760/760 [00:45<00:00, 16.83it/s]


val loss: 0.44452151044909105


Epoch 36 Loss: 0.453960042785316 OOMs: 0: 100%|██████████| 3412/3412 [10:05<00:00,  5.64it/s]
100%|██████████| 760/760 [00:45<00:00, 16.78it/s]


val loss: 0.4453182041350948


Epoch 37 Loss: 0.4468786799068293 OOMs: 0: 100%|██████████| 3412/3412 [09:59<00:00,  5.69it/s]
100%|██████████| 760/760 [00:44<00:00, 16.98it/s]


val loss: 0.4363552753225361


Epoch 38 Loss: 0.44097986889420054 OOMs: 0: 100%|██████████| 3412/3412 [09:51<00:00,  5.77it/s]
100%|██████████| 760/760 [00:44<00:00, 17.01it/s]


val loss: 0.4191904941524722


Epoch 39 Loss: 0.4261215449838188 OOMs: 0: 100%|██████████| 3412/3412 [09:53<00:00,  5.75it/s]
100%|██████████| 760/760 [00:44<00:00, 17.01it/s]


val loss: 0.4269484042042964


Epoch 40 Loss: 0.41709432092572785 OOMs: 0: 100%|██████████| 3412/3412 [10:01<00:00,  5.68it/s]
100%|██████████| 760/760 [00:44<00:00, 16.92it/s]


val loss: 0.4197868126150417


Epoch 41 Loss: 0.40567801917978463 OOMs: 0: 100%|██████████| 3412/3412 [10:07<00:00,  5.62it/s]
100%|██████████| 760/760 [00:45<00:00, 16.76it/s]


val loss: 0.43316404010218224


Epoch 42 Loss: 0.3920990620408937 OOMs: 0: 100%|██████████| 3412/3412 [10:01<00:00,  5.67it/s]
100%|██████████| 760/760 [00:45<00:00, 16.72it/s]


val loss: 0.4166604302308865


Epoch 43 Loss: 0.3784568996757164 OOMs: 0: 100%|██████████| 3412/3412 [09:53<00:00,  5.75it/s]
100%|██████████| 760/760 [00:44<00:00, 17.03it/s]


val loss: 0.4123418648261577


Epoch 44 Loss: 0.3677389472297197 OOMs: 0: 100%|██████████| 3412/3412 [09:52<00:00,  5.75it/s]
100%|██████████| 760/760 [00:44<00:00, 17.05it/s]


val loss: 0.42210208220946555


Epoch 45 Loss: 0.3587652689745652 OOMs: 0: 100%|██████████| 3412/3412 [09:55<00:00,  5.73it/s]
100%|██████████| 760/760 [00:45<00:00, 16.84it/s]


val loss: 0.4158411978891021


Epoch 46 Loss: 0.3492457014404822 OOMs: 0: 100%|██████████| 3412/3412 [09:56<00:00,  5.72it/s]
100%|██████████| 760/760 [00:45<00:00, 16.80it/s]


val loss: 0.41182626708361664


Epoch 47 Loss: 0.34264299505611867 OOMs: 0: 100%|██████████| 3412/3412 [10:04<00:00,  5.64it/s]
100%|██████████| 760/760 [00:44<00:00, 17.06it/s]


val loss: 0.418138586501836


Epoch 48 Loss: 0.3376249282512757 OOMs: 0: 100%|██████████| 3412/3412 [10:09<00:00,  5.60it/s]
100%|██████████| 760/760 [00:45<00:00, 16.65it/s]


val loss: 0.40950466965227145


Epoch 49 Loss: 0.3346729422531856 OOMs: 0: 100%|██████████| 3412/3412 [10:14<00:00,  5.55it/s]
100%|██████████| 760/760 [00:45<00:00, 16.67it/s]


val loss: 0.4032145107017928


Epoch 50 Loss: 0.33367977603169424 OOMs: 0: 100%|██████████| 3412/3412 [10:14<00:00,  5.55it/s]
100%|██████████| 760/760 [00:45<00:00, 16.69it/s]


val loss: 0.40128159113963574
