In [1]:
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
import pickle
from tqdm import tqdm

In [2]:
#set seed for everything
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

# Config

In [3]:
config = {
    "seed": 0,
    "cutoff_date": "2020-01-01",
    "test_cutoff_date": "2022-05-01",
    "max_len": 384,
    "batch_size": 1,
    "learning_rate": 1e-4,
    "weight_decay": 0.0,
    "mixed_precision": "bf16",
    "model_config_path": "../working/configs/pairwise.yaml",  # Adjust path as needed
    "epochs": 10,
    "cos_epoch": 5,
    "loss_power_scale": 1.0,
    "max_cycles": 1,
    "grad_clip": 0.1,
    "gradient_accumulation_steps": 1,
    "d_clamp": 30,
    "max_len_filter": 9999999,
    "min_len_filter": 10, 
    "structural_violation_epoch": 50,
    "balance_weight": False,
}

# Get data and do some data processing¶


In [4]:
# Load data

train_sequences=pd.read_csv("/kaggle/input/stanford-rna-3d-folding/train_sequences.csv")
train_labels=pd.read_csv("/kaggle/input/stanford-rna-3d-folding/train_labels.csv")

In [5]:
train_labels["pdb_id"] = train_labels["ID"].apply(lambda x: x.split("_")[0]+'_'+x.split("_")[1])
train_labels["pdb_id"] 

0         1SCL_A
1         1SCL_A
2         1SCL_A
3         1SCL_A
4         1SCL_A
           ...  
137090    8Z1F_T
137091    8Z1F_T
137092    8Z1F_T
137093    8Z1F_T
137094    8Z1F_T
Name: pdb_id, Length: 137095, dtype: object

In [6]:
float('Nan')

nan

In [7]:
all_xyz=[]

for pdb_id in tqdm(train_sequences['target_id']):
    df = train_labels[train_labels["pdb_id"]==pdb_id]
    #break
    xyz=df[['x_1','y_1','z_1']].to_numpy().astype('float32')
    xyz[xyz<-1e17]=float('Nan');
    all_xyz.append(xyz)


df

  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xyz<-1e17]=float('Nan');
  xyz[xy

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,pdb_id
137009,8Z1F_T_1,G,1,103.195999,112.250999,104.455002,8Z1F_T
137010,8Z1F_T_2,G,2,107.467003,108.984001,106.205002,8Z1F_T
137011,8Z1F_T_3,U,3,111.919998,107.942001,109.775002,8Z1F_T
137012,8Z1F_T_4,A,4,114.685997,108.813004,114.404999,8Z1F_T
137013,8Z1F_T_5,A,5,114.921997,110.031998,120.849998,8Z1F_T
...,...,...,...,...,...,...,...
137090,8Z1F_T_82,U,82,,,,8Z1F_T
137091,8Z1F_T_83,C,83,,,,8Z1F_T
137092,8Z1F_T_84,A,84,,,,8Z1F_T
137093,8Z1F_T_85,U,85,,,,8Z1F_T


In [8]:
# filter the data
# Filter and process data
filter_nan = []
max_len = 0
for xyz in all_xyz:
    if len(xyz) > max_len:
        max_len = len(xyz)

    #fill -1e18 masked sequences to nans
    
    #sugar_xyz = np.stack([nt_xyz['sugar_ring'] for nt_xyz in xyz], axis=0)
    filter_nan.append((np.isnan(xyz).mean() <= 0.5) & \
                      (len(xyz)<config['max_len_filter']) & \
                      (len(xyz)>config['min_len_filter']))

print(f"Longest sequence in train: {max_len}")

filter_nan = np.array(filter_nan)
non_nan_indices = np.arange(len(filter_nan))[filter_nan]

train_sequences = train_sequences.loc[non_nan_indices].reset_index(drop=True)
all_xyz=[all_xyz[i] for i in non_nan_indices]

Longest sequence in train: 4298


In [9]:
#pack data into a dictionary

data={
      "sequence":train_sequences['sequence'].to_list(),
      "temporal_cutoff": train_sequences['temporal_cutoff'].to_list(),
      "description": train_sequences['description'].to_list(),
      "all_sequences": train_sequences['all_sequences'].to_list(),
      "xyz": all_xyz
}

# Split train data into train/val/test¶
We will simply do a temporal split, because that's how testing is done in structural biology in general (in actual blind tests)

In [10]:
# Split data into train and test
all_index = np.arange(len(data['sequence']))
cutoff_date = pd.Timestamp(config['cutoff_date'])
test_cutoff_date = pd.Timestamp(config['test_cutoff_date'])
train_index = [i for i, d in enumerate(data['temporal_cutoff']) if pd.Timestamp(d) <= cutoff_date]
test_index = [i for i, d in enumerate(data['temporal_cutoff']) if pd.Timestamp(d) > cutoff_date and pd.Timestamp(d) <= test_cutoff_date]

In [11]:
print(f"Train size: {len(train_index)}")
print(f"Test size: {len(test_index)}")

Train size: 542
Test size: 80


# Get pytorch dataset¶

In [12]:
from torch.utils.data import Dataset, DataLoader
from ast import literal_eval

def get_ct(bp,s):
    ct_matrix=np.zeros((len(s),len(s)))
    for b in bp:
        ct_matrix[b[0]-1,b[1]-1]=1
    return ct_matrix

class RNA3D_Dataset(Dataset):
    def __init__(self,indices,data):
        self.indices=indices
        self.data=data
        self.tokens={nt:i for i,nt in enumerate('ACGU')}

    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):

        idx=self.indices[idx]
        sequence=[self.tokens[nt] for nt in (self.data['sequence'][idx])]
        sequence=np.array(sequence)
        sequence=torch.tensor(sequence)

        #get C1' xyz
        xyz=self.data['xyz'][idx]
        xyz=torch.tensor(np.array(xyz))


        if len(sequence)>config['max_len']:
            crop_start=np.random.randint(len(sequence)-config['max_len'])
            crop_end=crop_start+config['max_len']

            sequence=sequence[crop_start:crop_end]
            xyz=xyz[crop_start:crop_end]
        

        return {'sequence':sequence,
                'xyz':xyz}

In [13]:
train_dataset=RNA3D_Dataset(train_index,data)
val_dataset=RNA3D_Dataset(test_index,data)

In [14]:
import plotly.graph_objects as go
import numpy as np



# Example: Generate an Nx3 matrix
xyz = train_dataset[200]['xyz']  # Replace this with your actual Nx3 data
N = len(xyz)


for _ in range(2): #plot twice because it doesnt show up on first try for some reason
    # Extract columns
    x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
    
    # Create the 3D scatter plot
    fig = go.Figure(data=[go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        marker=dict(
            size=5,
            color=z,  # Coloring based on z-value
            colorscale='Viridis',  # Choose a colorscale
            opacity=0.8
        )
    )])
    
    # Customize layout
    fig.update_layout(
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z"
        ),
        title="3D Scatter Plot"
    )

fig.show()
    

In [15]:
train_loader=DataLoader(train_dataset,batch_size=1,shuffle=True)
val_loader=DataLoader(val_dataset,batch_size=1,shuffle=False)

# Get RibonanzaNet¶
We will add a linear layer to predict xyz of C1' atoms

In [16]:
! pip install einops




In [17]:
import sys

sys.path.append("/kaggle/input/ribonanzanet2d-final")


from Network import *
import yaml



class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        self.entries=entries

    def print(self):
        print(self.entries)

def load_config_from_yaml(file_path):
    with open(file_path, 'r') as file:
        config = yaml.safe_load(file)
    return Config(**config)



class finetuned_RibonanzaNet(RibonanzaNet):
    def __init__(self, config, pretrained=False):
        config.dropout=0.1
        super(finetuned_RibonanzaNet, self).__init__(config)
        if pretrained:
            self.load_state_dict(torch.load("/kaggle/input/ribonanzanet-weights/RibonanzaNet.pt",map_location='cpu'))
        # self.ct_predictor=nn.Sequential(nn.Linear(64,256),
        #                                 nn.ReLU(),
        #                                 nn.Linear(256,64),
        #                                 nn.ReLU(),
        #                                 nn.Linear(64,1)) 
        self.dropout=nn.Dropout(0.0)
        self.xyz_predictor=nn.Linear(256,3)


    
    def forward(self,src):
        
        #with torch.no_grad():
        sequence_features, pairwise_features=self.get_embeddings(src, torch.ones_like(src).long().to(src.device))


        xyz=self.xyz_predictor(sequence_features)

        return xyz

In [18]:
model=finetuned_RibonanzaNet(load_config_from_yaml("/kaggle/input/ribonanzanet2d-final/configs/pairwise.yaml"),pretrained=True).cuda()


constructing 9 ConvTransformerEncoderLayers



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



# Training loop¶
we will use dRMSD loss on the predicted xyz. the loss function is invariant to translations, rotations, and reflections. because dRMSD is invariant to reflections, it cannot distinguish chiral structures, so there may be better loss functions

In [19]:
def calculate_distance_matrix(X,Y,epsilon=1e-4):
    return (torch.square(X[:,None]-Y[None,:])+epsilon).sum(-1).sqrt()


def dRMSD(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=None):
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)



    mask=~torch.isnan(gt_dm)
    mask[torch.eye(mask.shape[0]).bool()]=False

    if d_clamp is not None:
        rmsd=(torch.square(pred_dm[mask]-gt_dm[mask])+epsilon).clip(0,d_clamp**2)
    else:
        rmsd=torch.square(pred_dm[mask]-gt_dm[mask])+epsilon

    return rmsd.sqrt().mean()/Z

def local_dRMSD(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=30):
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)



    mask=(~torch.isnan(gt_dm))*(gt_dm<d_clamp)
    mask[torch.eye(mask.shape[0]).bool()]=False



    rmsd=torch.square(pred_dm[mask]-gt_dm[mask])+epsilon
    # rmsd=(torch.square(pred_dm[mask]-gt_dm[mask])+epsilon).sqrt()/Z
    #rmsd=torch.abs(pred_dm[mask]-gt_dm[mask])/Z
    return rmsd.sqrt().mean()/Z

def dRMAE(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=None):
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)



    mask=~torch.isnan(gt_dm)
    mask[torch.eye(mask.shape[0]).bool()]=False

    rmsd=torch.abs(pred_dm[mask]-gt_dm[mask])

    return rmsd.mean()/Z

import torch

def align_svd_mae(input, target, Z=10):
    """
    Aligns the input (Nx3) to target (Nx3) using SVD-based Procrustes alignment
    and computes RMSD loss.
    
    Args:
        input (torch.Tensor): Nx3 tensor representing the input points.
        target (torch.Tensor): Nx3 tensor representing the target points.
    
    Returns:
        aligned_input (torch.Tensor): Nx3 aligned input.
        rmsd_loss (torch.Tensor): RMSD loss.
    """
    assert input.shape == target.shape, "Input and target must have the same shape"

    #mask 
    mask=~torch.isnan(target.sum(-1))

    input=input[mask]
    target=target[mask]
    
    # Compute centroids
    centroid_input = input.mean(dim=0, keepdim=True)
    centroid_target = target.mean(dim=0, keepdim=True)

    # Center the points
    input_centered = input - centroid_input.detach()
    target_centered = target - centroid_target

    # Compute covariance matrix
    cov_matrix = input_centered.T @ target_centered

    # SVD to find optimal rotation
    U, S, Vt = torch.svd(cov_matrix)

    # Compute rotation matrix
    R = Vt @ U.T

    # Ensure a proper rotation (det(R) = 1, no reflection)
    if torch.det(R) < 0:
        Vt[-1, :] *= -1
        R = Vt @ U.T

    # Rotate input
    aligned_input = (input_centered @ R.T.detach()) + centroid_target.detach()

    # # Compute RMSD loss
    # rmsd_loss = torch.sqrt(((aligned_input - target) ** 2).mean())

    # rmsd_loss = torch.sqrt(((aligned_input - target) ** 2).mean())
    
    # return aligned_input, rmsd_loss
    return torch.abs(aligned_input-target).mean()/Z

In [20]:
from tqdm import tqdm
from torch.amp import GradScaler

epochs=50
cos_epoch=35


best_loss=np.inf
optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.0, lr=0.0001) #no weight decay following AF

batch_size=1

#for cycle in range(2):

criterion=torch.nn.BCEWithLogitsLoss(reduction='none')

scaler = GradScaler()

schedule=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(epochs-cos_epoch)*len(train_loader)//batch_size)

best_val_loss=99999999999
for epoch in range(epochs):
    model.train()
    tbar=tqdm(train_loader)
    total_loss=0
    oom=0
    for idx, batch in enumerate(tbar):
        #try:
        sequence=batch['sequence'].cuda()
        gt_xyz=batch['xyz'].cuda().squeeze()

        #with torch.autocast(device_type='cuda', dtype=torch.float16):
        pred_xyz=model(sequence).squeeze()
        
        loss=dRMAE(pred_xyz,pred_xyz,gt_xyz,gt_xyz) + align_svd_mae(pred_xyz, gt_xyz)
             #local_dRMSD(pred_xyz,pred_xyz,gt_xyz,gt_xyz)

        if loss!=loss:
            stop

        
        (loss/batch_size).backward()

        if (idx+1)%batch_size==0 or idx+1 == len(tbar):

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            optimizer.zero_grad()
            # scaler.scale(loss/batch_size).backward()
            # scaler.unscale_(optimizer)
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            # scaler.step(optimizer)
            # scaler.update()

            
            if (epoch+1)>cos_epoch:
                schedule.step()
        #schedule.step()
        total_loss+=loss.item()
        
        tbar.set_description(f"Epoch {epoch + 1} Loss: {total_loss/(idx+1)} OOMs: {oom}")



        # except Exception:
        #     #print(Exception)
        #     oom+=1
    tbar=tqdm(val_loader)
    model.eval()
    val_preds=[]
    val_loss=0
    for idx, batch in enumerate(tbar):
        sequence=batch['sequence'].cuda()
        gt_xyz=batch['xyz'].cuda().squeeze()

        with torch.no_grad():
            pred_xyz=model(sequence).squeeze()
            loss=dRMAE(pred_xyz,pred_xyz,gt_xyz,gt_xyz)
            
        val_loss+=loss.item()
        val_preds.append([gt_xyz.cpu().numpy(),pred_xyz.cpu().numpy()])
    val_loss=val_loss/len(tbar)
    print(f"val loss: {val_loss}")
    
    
    
    if val_loss<best_val_loss:
        best_val_loss=val_loss
        best_preds=val_preds
        torch.save(model.state_dict(),'RibonanzaNet-3D.pt')

    # 1.053595052265986 train loss after epoch 0
torch.save(model.state_dict(),'RibonanzaNet-3D-final.pt')


torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.

Epoch 1 Loss: 3.2199990492025425 OOMs: 0: 100%|██████████| 542/542 [01:53<00:00,  4.80it/s]

None of the inputs have requires_grad=True. Gradients will be None

100%|██████████| 80/80 [00:06<00:00, 12.64it/s]


val loss: 3.150636697560549


Epoch 2 Loss: 2.5668994640292278 OOMs: 0: 100%|██████████| 542/542 [01:48<00:00,  4.99it/s]
100%|██████████| 80/80 [00:06<00:00, 12.70it/s]


val loss: 2.5698969673365353


Epoch 3 Loss: 2.0072686172939314 OOMs: 0: 100%|██████████| 542/542 [01:48<00:00,  4.98it/s]
100%|██████████| 80/80 [00:06<00:00, 12.71it/s]


val loss: 2.0730048954486846


Epoch 4 Loss: 1.7682882794814796 OOMs: 0: 100%|██████████| 542/542 [01:49<00:00,  4.97it/s]
100%|██████████| 80/80 [00:06<00:00, 12.76it/s]


val loss: 1.9012155197560787


Epoch 5 Loss: 1.5840497852691424 OOMs: 0: 100%|██████████| 542/542 [01:49<00:00,  4.94it/s]
100%|██████████| 80/80 [00:06<00:00, 12.70it/s]


val loss: 1.7273695463314653


Epoch 6 Loss: 1.4373803024912233 OOMs: 0: 100%|██████████| 542/542 [01:49<00:00,  4.94it/s]
100%|██████████| 80/80 [00:06<00:00, 12.56it/s]


val loss: 1.5960239328444004


Epoch 7 Loss: 1.3413980307944147 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.88it/s]
100%|██████████| 80/80 [00:06<00:00, 12.63it/s]


val loss: 1.5905947024002671


Epoch 8 Loss: 1.2375505682067238 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.85it/s]
100%|██████████| 80/80 [00:06<00:00, 12.68it/s]


val loss: 1.526562505401671


Epoch 9 Loss: 1.132738844165063 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.87it/s]
100%|██████████| 80/80 [00:06<00:00, 12.74it/s]


val loss: 1.4162183694541455


Epoch 10 Loss: 1.0806735837041672 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.91it/s]
100%|██████████| 80/80 [00:06<00:00, 12.68it/s]


val loss: 1.411273181065917


Epoch 11 Loss: 1.0130509767970035 OOMs: 0: 100%|██████████| 542/542 [01:49<00:00,  4.95it/s]
100%|██████████| 80/80 [00:06<00:00, 12.72it/s]


val loss: 1.358323416672647


Epoch 12 Loss: 0.989150149081026 OOMs: 0: 100%|██████████| 542/542 [01:48<00:00,  4.97it/s]
100%|██████████| 80/80 [00:06<00:00, 12.73it/s]


val loss: 1.386739151366055


Epoch 13 Loss: 0.9635479222859403 OOMs: 0: 100%|██████████| 542/542 [01:49<00:00,  4.95it/s]
100%|██████████| 80/80 [00:06<00:00, 12.57it/s]


val loss: 1.4007564082741737


Epoch 14 Loss: 0.934150873460013 OOMs: 0: 100%|██████████| 542/542 [01:52<00:00,  4.82it/s]
100%|██████████| 80/80 [00:06<00:00, 12.56it/s]


val loss: 1.3200245710089802


Epoch 15 Loss: 0.9049008390230446 OOMs: 0: 100%|██████████| 542/542 [01:52<00:00,  4.83it/s]
100%|██████████| 80/80 [00:06<00:00, 12.58it/s]


val loss: 1.3103256629779936


Epoch 16 Loss: 0.8676023795595909 OOMs: 0: 100%|██████████| 542/542 [01:53<00:00,  4.80it/s]
100%|██████████| 80/80 [00:06<00:00, 12.56it/s]


val loss: 1.36582504324615


Epoch 17 Loss: 0.8412119717890486 OOMs: 0: 100%|██████████| 542/542 [01:53<00:00,  4.79it/s]
100%|██████████| 80/80 [00:06<00:00, 12.50it/s]


val loss: 1.2612878596410155


Epoch 18 Loss: 0.8331168623318091 OOMs: 0: 100%|██████████| 542/542 [01:52<00:00,  4.84it/s]
100%|██████████| 80/80 [00:06<00:00, 12.65it/s]


val loss: 1.2719830891117454


Epoch 19 Loss: 0.832197459256517 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.91it/s]
100%|██████████| 80/80 [00:06<00:00, 12.63it/s]


val loss: 1.2272241713479162


Epoch 20 Loss: 0.8011726759841521 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.91it/s]
100%|██████████| 80/80 [00:06<00:00, 12.61it/s]


val loss: 1.1903072035871447


Epoch 21 Loss: 0.7872996571649045 OOMs: 0: 100%|██████████| 542/542 [01:49<00:00,  4.94it/s]
100%|██████████| 80/80 [00:06<00:00, 12.72it/s]


val loss: 1.2523045639507473


Epoch 22 Loss: 0.7692655583919195 OOMs: 0: 100%|██████████| 542/542 [01:49<00:00,  4.95it/s]
100%|██████████| 80/80 [00:06<00:00, 12.69it/s]


val loss: 1.1895527552813292


Epoch 23 Loss: 0.7515676390255949 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.89it/s]
100%|██████████| 80/80 [00:06<00:00, 12.59it/s]


val loss: 1.2062565283849835


Epoch 24 Loss: 0.7195423061777305 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.89it/s]
100%|██████████| 80/80 [00:06<00:00, 12.50it/s]


val loss: 1.209321746043861


Epoch 25 Loss: 0.7355265551387604 OOMs: 0: 100%|██████████| 542/542 [01:49<00:00,  4.95it/s]
100%|██████████| 80/80 [00:06<00:00, 12.65it/s]


val loss: 1.196804241091013


Epoch 26 Loss: 0.711640034544512 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.93it/s]
100%|██████████| 80/80 [00:06<00:00, 12.66it/s]


val loss: 1.1817723407410086


Epoch 27 Loss: 0.6983844024228874 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.92it/s]
100%|██████████| 80/80 [00:06<00:00, 12.79it/s]


val loss: 1.1909394411370158


Epoch 28 Loss: 0.6949924044036118 OOMs: 0: 100%|██████████| 542/542 [01:49<00:00,  4.96it/s]
100%|██████████| 80/80 [00:06<00:00, 12.61it/s]


val loss: 1.1709523915313185


Epoch 29 Loss: 0.6807934867967319 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.90it/s]
100%|██████████| 80/80 [00:06<00:00, 12.60it/s]


val loss: 1.1749008664861322


Epoch 30 Loss: 0.672400250164895 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.91it/s]
100%|██████████| 80/80 [00:06<00:00, 12.62it/s]


val loss: 1.1667254129424691


Epoch 31 Loss: 0.6531378618300621 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.90it/s]
100%|██████████| 80/80 [00:06<00:00, 12.56it/s]


val loss: 1.1292256562039256


Epoch 32 Loss: 0.6421718541464022 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.88it/s]
100%|██████████| 80/80 [00:06<00:00, 12.67it/s]


val loss: 1.2099026416428387


Epoch 33 Loss: 0.6324000246410441 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.90it/s]
100%|██████████| 80/80 [00:06<00:00, 12.63it/s]


val loss: 1.1636741166003048


Epoch 34 Loss: 0.6369210529618817 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.88it/s]
100%|██████████| 80/80 [00:06<00:00, 12.62it/s]


val loss: 1.2289577820338309


Epoch 35 Loss: 0.6191186050564821 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.89it/s]
100%|██████████| 80/80 [00:06<00:00, 12.66it/s]


val loss: 1.1443675881251694


Epoch 36 Loss: 0.6168897275115292 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.88it/s]
100%|██████████| 80/80 [00:06<00:00, 12.67it/s]


val loss: 1.1731052215211093


Epoch 37 Loss: 0.6279423405653436 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.88it/s]
100%|██████████| 80/80 [00:06<00:00, 12.61it/s]


val loss: 1.1323314106091857


Epoch 38 Loss: 0.59934857963306 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.88it/s]
100%|██████████| 80/80 [00:06<00:00, 12.64it/s]


val loss: 1.1745440794155002


Epoch 39 Loss: 0.5804181288639118 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.88it/s]
100%|██████████| 80/80 [00:06<00:00, 12.55it/s]


val loss: 1.1445279031060636


Epoch 40 Loss: 0.5749931129939662 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.87it/s]
100%|██████████| 80/80 [00:06<00:00, 12.65it/s]


val loss: 1.1050112588331102


Epoch 41 Loss: 0.5586001111737476 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.85it/s]
100%|██████████| 80/80 [00:06<00:00, 12.57it/s]


val loss: 1.1341989050619303


Epoch 42 Loss: 0.5419458150176325 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.86it/s]
100%|██████████| 80/80 [00:06<00:00, 12.48it/s]


val loss: 1.1200986428186297


Epoch 43 Loss: 0.5424057700124834 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.86it/s]
100%|██████████| 80/80 [00:06<00:00, 12.66it/s]


val loss: 1.1600506937131285


Epoch 44 Loss: 0.5264444732110659 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.89it/s]
100%|██████████| 80/80 [00:06<00:00, 12.67it/s]


val loss: 1.1446858283132315


Epoch 45 Loss: 0.5210942760577281 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.89it/s]
100%|██████████| 80/80 [00:06<00:00, 12.66it/s]


val loss: 1.100903124921024


Epoch 46 Loss: 0.5189894529478797 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.92it/s]
100%|██████████| 80/80 [00:06<00:00, 12.63it/s]


val loss: 1.0961286412551998


Epoch 47 Loss: 0.5146828567561078 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.90it/s]
100%|██████████| 80/80 [00:06<00:00, 12.59it/s]


val loss: 1.1295567157678306


Epoch 48 Loss: 0.48946112457566593 OOMs: 0: 100%|██████████| 542/542 [01:51<00:00,  4.88it/s]
100%|██████████| 80/80 [00:06<00:00, 12.71it/s]


val loss: 1.1141766550950707


Epoch 49 Loss: 0.49631714188524717 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.92it/s]
100%|██████████| 80/80 [00:06<00:00, 12.61it/s]


val loss: 1.1036432842724024


Epoch 50 Loss: 0.48943002711692857 OOMs: 0: 100%|██████████| 542/542 [01:50<00:00,  4.89it/s]
100%|██████████| 80/80 [00:06<00:00, 12.73it/s]


val loss: 1.1280615611933171
