In [None]:
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
import pickle
from tqdm import tqdm

In [None]:
# Seed every RNG this notebook touches: torch (model init, DataLoader shuffling),
# NumPy's global RNG (the dataset's random crops) and Python's `random`.
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(0)

In [None]:
# Experiment settings.
# NOTE(review): in this notebook only cutoff_date, test_cutoff_date, max_len,
# min_len_filter and max_len_filter are read; the training cell hard-codes its own
# epochs / cos_epoch / learning rate / grad clip instead of using the values here.
config = dict(
    seed=0,
    cutoff_date="2020-01-01",        # train split: temporal_cutoff <= this date
    test_cutoff_date="2022-05-01",   # validation split: cutoff_date < temporal_cutoff <= this date
    max_len=384,                     # sequences longer than this are randomly cropped
    batch_size=1,
    learning_rate=1e-4,
    weight_decay=0.0,
    mixed_precision="bf16",
    model_config_path="../working/configs/pairwise.yaml",  # Adjust path as needed
    epochs=10,
    cos_epoch=5,
    loss_power_scale=1.0,
    max_cycles=1,
    grad_clip=0.1,
    gradient_accumulation_steps=1,
    d_clamp=30,
    max_len_filter=9999999,          # keep structures strictly shorter than this
    min_len_filter=10,               # keep structures strictly longer than this
    structural_violation_epoch=50,
    balance_weight=False,
)

In [None]:
# Load the training CSVs from Google Drive (in Colab, the drive must already be mounted).
# The dataset directory is defined once so it only needs to be changed in one place.
DATA_DIR = "/content/drive/MyDrive/Colab Notebooks/RNA_DATA/stanford-rna-3d-folding"

train_sequences = pd.read_csv(f"{DATA_DIR}/train_sequences.csv")
train_labels = pd.read_csv(f"{DATA_DIR}/train_labels.csv")

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# Derive the target id "<PDB>_<chain>" from each per-residue label ID "<PDB>_<chain>_<resid>"
# (e.g. 8Z1F_T_1 -> 8Z1F_T) so label rows can be matched to train_sequences["target_id"].
train_labels["pdb_id"] = train_labels["ID"].map(
    lambda label_id: "{}_{}".format(*label_id.split("_"))
)
train_labels["pdb_id"]

Unnamed: 0,pdb_id
0,1SCL_A
1,1SCL_A
2,1SCL_A
3,1SCL_A
4,1SCL_A
...,...
137090,8Z1F_T
137091,8Z1F_T
137092,8Z1F_T
137093,8Z1F_T


In [None]:
# Scratch check (safe to delete): float('Nan') parses to a NaN float, displayed as `nan`.
float('Nan')

nan

In [None]:
# Collect the C1' coordinates (x_1, y_1, z_1) of every target, in train_sequences order.
# Grouping the labels once avoids re-scanning every label row for each target.
# Row order inside each group matches train_labels, so residue order is unchanged.
labels_by_pdb = dict(tuple(train_labels.groupby("pdb_id", sort=False)))
empty_labels = train_labels.iloc[:0]  # targets without labels -> (0, 3) array, as before

all_xyz = []
for pdb_id in tqdm(train_sequences["target_id"]):
    df = labels_by_pdb.get(pdb_id, empty_labels)
    xyz = df[['x_1', 'y_1', 'z_1']].to_numpy().astype('float32')
    # Coordinates below -1e17 are the -1e18 "missing" sentinels; turn them into NaN.
    xyz[xyz < -1e17] = np.nan
    all_xyz.append(xyz)

df

100%|██████████| 844/844 [00:08<00:00, 97.48it/s]


Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,pdb_id
137009,8Z1F_T_1,G,1,103.195999,112.251000,104.455002,8Z1F_T
137010,8Z1F_T_2,G,2,107.467003,108.984001,106.205002,8Z1F_T
137011,8Z1F_T_3,U,3,111.919998,107.942001,109.775002,8Z1F_T
137012,8Z1F_T_4,A,4,114.685997,108.813003,114.404999,8Z1F_T
137013,8Z1F_T_5,A,5,114.921997,110.031998,120.849998,8Z1F_T
...,...,...,...,...,...,...,...
137090,8Z1F_T_82,U,82,,,,8Z1F_T
137091,8Z1F_T_83,C,83,,,,8Z1F_T
137092,8Z1F_T_84,A,84,,,,8Z1F_T
137093,8Z1F_T_85,U,85,,,,8Z1F_T


In [None]:
# Filter structures: keep those with at most 50% missing (NaN) coordinates whose length is
# strictly between config['min_len_filter'] and config['max_len_filter'].
# max_len is reported over the unfiltered set.
max_len = max((len(xyz) for xyz in all_xyz), default=0)
print(f"Longest sequence in train: {max_len}")

filter_nan = np.array(
    [
        (np.isnan(xyz).mean() <= 0.5)
        and (config['min_len_filter'] < len(xyz) < config['max_len_filter'])
        for xyz in all_xyz
    ],
    dtype=bool,
)
non_nan_indices = np.flatnonzero(filter_nan)

# Select by position (iloc) so this does not rely on train_sequences having a default RangeIndex.
train_sequences = train_sequences.iloc[non_nan_indices].reset_index(drop=True)
all_xyz = [all_xyz[i] for i in non_nan_indices]

Longest sequence in train: 4298


In [None]:
# Pack the per-target columns into plain lists (aligned by position) plus the coordinate arrays.
data = {
    column: train_sequences[column].to_list()
    for column in ("sequence", "temporal_cutoff", "description", "all_sequences")
}
data["xyz"] = all_xyz

In [None]:
# Temporal split: train on structures released up to cutoff_date; hold out those released
# after cutoff_date and up to test_cutoff_date (anything later is unused).
all_index = np.arange(len(data['sequence']))
cutoff_date = pd.Timestamp(config['cutoff_date'])
test_cutoff_date = pd.Timestamp(config['test_cutoff_date'])
train_index = [
    i for i, released in enumerate(map(pd.Timestamp, data['temporal_cutoff']))
    if released <= cutoff_date
]
test_index = [
    i for i, released in enumerate(map(pd.Timestamp, data['temporal_cutoff']))
    if cutoff_date < released <= test_cutoff_date
]

In [None]:
# Report how many structures landed in each temporal split.
print(f"Train size: {len(train_index)}", f"Test size: {len(test_index)}", sep="\n")

Train size: 542
Test size: 80


In [None]:
from torch.utils.data import Dataset, DataLoader
from ast import literal_eval

def get_ct(bp,s):
    """Build a len(s) x len(s) float matrix with a 1 at every listed base pair.

    Each entry of `bp` is indexed as (i, j, ...) with 1-based positions; only the
    [i-1, j-1] cell is set (the transposed cell is not mirrored).
    """
    size = len(s)
    contact = np.zeros((size, size))
    for pair in bp:
        row, col = pair[0] - 1, pair[1] - 1
        contact[row, col] = 1
    return contact

class RNA3D_Dataset(Dataset):
    """Map-style dataset of tokenized RNA sequences and their per-residue coordinates.

    Args:
        indices: positions into the lists held by `data` that this dataset exposes.
        data: dict with at least 'sequence' (list of strings over 'ACGU') and 'xyz'
            (list of per-residue coordinate arrays, NaN where missing).
        max_len: sequences longer than this are randomly cropped to this length;
            when None, config['max_len'] is read at access time (original behaviour).
    """

    def __init__(self,indices,data,max_len=None):
        self.indices=indices
        self.data=data
        self.max_len=max_len
        self.tokens={nt:i for i,nt in enumerate('ACGU')}

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return {'sequence': int64 tensor of token ids, 'xyz': coordinate tensor}.

        Both are cropped to the same random window when the sequence exceeds max_len.
        Raises KeyError for nucleotides outside 'ACGU'.
        """
        idx=self.indices[idx]
        sequence=torch.tensor(np.array([self.tokens[nt] for nt in self.data['sequence'][idx]]))

        # C1' coordinates for this entry
        xyz=torch.tensor(np.array(self.data['xyz'][idx]))

        max_len=config['max_len'] if self.max_len is None else self.max_len
        if len(sequence)>max_len:
            # randint's upper bound is exclusive: +1 so the window ending at the last
            # residue (start = len - max_len) can also be sampled.
            crop_start=np.random.randint(len(sequence)-max_len+1)
            crop_end=crop_start+max_len

            sequence=sequence[crop_start:crop_end]
            xyz=xyz[crop_start:crop_end]

        return {'sequence':sequence,
                'xyz':xyz}

In [None]:
# Wrap the two temporal splits; the later-dated hold-out split serves as validation.
train_dataset, val_dataset = RNA3D_Dataset(train_index, data), RNA3D_Dataset(test_index, data)

In [None]:
import plotly.graph_objects as go
import numpy as np



# Visual sanity check: 3-D scatter of the C1' coordinates of one training example.
# train_dataset[200] applies the dataset's random crop when the sequence exceeds max_len,
# and residues with missing coordinates are NaN (presumably not drawn by plotly).
xyz = train_dataset[200]['xyz'].numpy()  # N x 3 array
N = len(xyz)
x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]

fig = go.Figure(data=[go.Scatter3d(
    x=x, y=y, z=z,
    mode='markers',
    marker=dict(
        size=5,
        color=z,  # Coloring based on z-value
        colorscale='Viridis',
        opacity=0.8
    )
)])

fig.update_layout(
    scene=dict(
        xaxis_title="X",
        yaxis_title="Y",
        zaxis_title="Z"
    ),
    title="3D Scatter Plot"
)

fig.show()


In [None]:
# One sequence per batch: sequences differ in length, which (presumably) is why batch_size
# stays at 1 — the default collate function cannot stack unequal-length tensors.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1)
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)

In [None]:
! pip install einops



In [None]:
import sys

sys.path.append("/content/drive/MyDrive/Colab Notebooks/RNA_DATA/ribonanzanet2d-final")


from Network import *
import yaml



class Config:
    """Attribute-style wrapper around keyword settings; the raw mapping is kept in `entries`."""

    def __init__(self, **entries):
        for key, value in entries.items():
            setattr(self, key, value)
        self.entries = entries

    def print(self):
        """Print the raw settings mapping."""
        print(self.entries)

def load_config_from_yaml(file_path):
    """Parse `file_path` with yaml.safe_load and wrap the resulting mapping in a Config."""
    with open(file_path, 'r') as handle:
        parsed = yaml.safe_load(handle)
    return Config(**parsed)



class finetuned_RibonanzaNet(RibonanzaNet):
    """RibonanzaNet backbone with a linear head that regresses 3 coordinates per residue.

    Args:
        config: model config object (e.g. from load_config_from_yaml). Its `dropout`
            attribute is overwritten with 0.1 (the caller's object is mutated) before
            the backbone is built.
        pretrained: if True, load backbone weights from `weights_path`.
        weights_path: path to the pretrained RibonanzaNet state dict.
    """

    def __init__(self, config, pretrained=False,
                 weights_path="/content/drive/MyDrive/Colab Notebooks/RNA_DATA/ribonanzanet-weights/RibonanzaNet.pt"):
        config.dropout=0.1
        super().__init__(config)
        if pretrained:
            # The file is only used as a state dict, so refuse arbitrary pickled objects.
            state_dict=torch.load(weights_path, map_location='cpu', weights_only=True)
            self.load_state_dict(state_dict)
        # NOTE(review): not used in forward below; it replaces any `dropout` attribute the
        # backbone may define — confirm whether that override is intended.
        self.dropout=nn.Dropout(0.0)
        # assumes the backbone's per-residue features are 256-dim — TODO confirm against config
        self.xyz_predictor=nn.Linear(256,3)

    def forward(self,src):
        """Predict coordinates for token ids `src`; every position is treated as valid (mask of ones)."""
        sequence_features, pairwise_features=self.get_embeddings(src, torch.ones_like(src, dtype=torch.long))

        xyz=self.xyz_predictor(sequence_features)

        return xyz

In [None]:
# Build the model from the pairwise RibonanzaNet config, load the pretrained backbone, move to GPU.
model = finetuned_RibonanzaNet(
    load_config_from_yaml(
        "/content/drive/MyDrive/Colab Notebooks/RNA_DATA/ribonanzanet2d-final/configs/pairwise.yaml"
    ),
    pretrained=True,
).cuda()


constructing 9 ConvTransformerEncoderLayers


In [None]:
def calculate_distance_matrix(X,Y,epsilon=1e-4):
    """Pairwise distances: out[i, j] = sqrt(sum_k((X[i, k] - Y[j, k])**2 + epsilon)).

    epsilon is added to every squared coordinate difference, so the sqrt argument stays
    positive (and its gradient finite) even for coincident points.
    """
    diff = X.unsqueeze(1) - Y.unsqueeze(0)
    return torch.sqrt(torch.sum(diff.pow(2) + epsilon, dim=-1))


def dRMSD(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=None):
    """Distance-matrix error between predicted and ground-truth structures.

    Builds the distance matrices pred_x<->pred_y and gt_x<->gt_y. Pairs whose ground-truth
    distance is NaN (missing coordinates) are ignored, and so is the main diagonal. Returns
    mean(sqrt((d_pred - d_gt)**2 + epsilon)) / Z, i.e. the mean of smoothed per-pair
    absolute errors (not the square root of the mean squared error).

    Args:
        epsilon: smoothing added to each squared error before the sqrt.
        Z: normalisation divisor.
        d_clamp: if given, each smoothed squared error is clipped to d_clamp**2.

    Returns:
        Scalar tensor.
    """
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)

    mask=~torch.isnan(gt_dm)
    # Build the diagonal mask on mask's own device to avoid a per-call host-to-device copy.
    mask[torch.eye(mask.shape[0], device=mask.device).bool()]=False

    sq_err=torch.square(pred_dm[mask]-gt_dm[mask])+epsilon
    if d_clamp is not None:
        sq_err=sq_err.clip(0,d_clamp**2)

    return sq_err.sqrt().mean()/Z

def local_dRMSD(pred_x,
                pred_y,
                gt_x,
                gt_y,
                epsilon=1e-4,Z=10,d_clamp=30):
    """Distance-matrix error restricted to locally close pairs.

    Like dRMSD, but only pairs whose ground-truth distance is below d_clamp are scored.
    Here d_clamp selects pairs rather than clipping errors, and it must not be None.
    NaN ground-truth distances and the main diagonal are ignored. Returns
    mean(sqrt((d_pred - d_gt)**2 + epsilon)) / Z.

    Returns:
        Scalar tensor.
    """
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)

    mask=(~torch.isnan(gt_dm)) & (gt_dm<d_clamp)
    # Build the diagonal mask on mask's own device to avoid a per-call host-to-device copy.
    mask[torch.eye(mask.shape[0], device=mask.device).bool()]=False

    sq_err=torch.square(pred_dm[mask]-gt_dm[mask])+epsilon
    return sq_err.sqrt().mean()/Z

def dRMAE(pred_x,
          pred_y,
          gt_x,
          gt_y,
          epsilon=1e-4,Z=10,d_clamp=None):
    """Distance-matrix mean absolute error between predicted and ground-truth structures.

    Returns mean(|d_pred[i, j] - d_gt[i, j]|) / Z over pairs whose ground-truth distance is
    not NaN, excluding the main diagonal. Because it compares distance matrices, the value
    does not depend on how either structure is rotated or translated.

    `epsilon` and `d_clamp` are accepted for signature compatibility with dRMSD but are not
    used. The distance matrices use calculate_distance_matrix's default epsilon.

    Returns:
        Scalar tensor.
    """
    pred_dm=calculate_distance_matrix(pred_x,pred_y)
    gt_dm=calculate_distance_matrix(gt_x,gt_y)

    mask=~torch.isnan(gt_dm)
    # Build the diagonal mask on mask's own device to avoid a per-call host-to-device copy.
    mask[torch.eye(mask.shape[0], device=mask.device).bool()]=False

    abs_err=torch.abs(pred_dm[mask]-gt_dm[mask])

    return abs_err.mean()/Z

import torch

def align_svd_mae(input, target, Z=10):
    """Mean absolute error between `input` and `target` after optimal rigid superposition.

    Rows whose target contains NaN are dropped. The input is centred and rotated onto the
    target with the Kabsch/SVD solution. Only proper rotations are allowed; when the best
    orthogonal map would be a reflection, the last singular direction is flipped. The MAE
    over the remaining coordinates is then divided by Z.

    The rotation and the input centroid are treated as constants, so gradients reach
    `input` only through its centred coordinates.

    Args:
        input (torch.Tensor): Nx3 predicted points.
        target (torch.Tensor): Nx3 reference points (NaN rows are ignored).
        Z: normalisation divisor.

    Returns:
        torch.Tensor: scalar MAE / Z.

    Raises:
        AssertionError: if the shapes differ.
    """
    assert input.shape == target.shape, "Input and target must have the same shape"

    # Ignore residues with missing ground-truth coordinates.
    mask=~torch.isnan(target.sum(-1))
    input=input[mask]
    target=target[mask]

    centroid_input = input.mean(dim=0, keepdim=True)
    centroid_target = target.mean(dim=0, keepdim=True)

    input_centered = input - centroid_input.detach()
    target_centered = target - centroid_target

    # The rotation is used as a constant, so compute it without building an autograd graph.
    with torch.no_grad():
        # Covariance H = P^T Q = U S V^T; the optimal rotation is R = V diag(1, .., d) U^T.
        cov_matrix = input_centered.T @ target_centered
        U, S, Vh = torch.linalg.svd(cov_matrix)
        V = Vh.T.clone()
        R = V @ U.T
        if torch.det(R) < 0:
            # Exclude reflections: negate the LAST COLUMN of V (the smallest singular direction).
            V[:, -1] = -V[:, -1]
            R = V @ U.T

    aligned_input = input_centered @ R.T + centroid_target.detach()

    return torch.abs(aligned_input-target).mean()/Z

In [None]:
from tqdm import tqdm
from torch.amp import GradScaler

# Training hyper-parameters.
# NOTE(review): these are hard-coded and differ from the `config` dict defined earlier
# (epochs=10, cos_epoch=5, grad_clip=0.1); the values below are the ones actually used.
epochs=50
cos_epoch=35  # constant LR for the first cos_epoch epochs, cosine annealing afterwards

best_loss=np.inf
optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.0, lr=0.0001) #no weight decay following AF

# Number of steps to accumulate gradients over before each optimizer step
# (the DataLoaders yield one sequence per step).
batch_size=1

# NOTE(review): `criterion` and `scaler` are not used below (mixed precision is disabled).
criterion=torch.nn.BCEWithLogitsLoss(reduction='none')
scaler = GradScaler()

# One scheduler step per optimizer step during the cosine phase.
schedule=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(epochs-cos_epoch)*len(train_loader)//batch_size)

best_val_loss=99999999999
for epoch in range(epochs):
    # ---- training ----
    model.train()
    tbar=tqdm(train_loader)
    total_loss=0
    oom=0  # kept for the progress-bar format; nothing increments it
    for idx, batch in enumerate(tbar):
        sequence=batch['sequence'].cuda()
        gt_xyz=batch['xyz'].cuda().squeeze()  # drop the size-1 batch dimension

        pred_xyz=model(sequence).squeeze()

        # Distance-matrix MAE plus MAE after optimal rigid superposition.
        loss=dRMAE(pred_xyz,pred_xyz,gt_xyz,gt_xyz) + align_svd_mae(pred_xyz, gt_xyz)

        # Stop on a NaN/inf loss before its gradients reach the weights.
        if not torch.isfinite(loss):
            raise FloatingPointError(
                f"Non-finite loss {loss.item()} at epoch {epoch + 1}, step {idx}"
            )

        (loss/batch_size).backward()

        if (idx+1)%batch_size==0 or idx+1 == len(tbar):
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            optimizer.zero_grad()

            if (epoch+1)>cos_epoch:
                schedule.step()
        total_loss+=loss.item()

        tbar.set_description(f"Epoch {epoch + 1} Loss: {total_loss/(idx+1)} OOMs: {oom}")

    # ---- validation (distance-matrix MAE only) ----
    # NOTE(review): RNA3D_Dataset also randomly crops long validation sequences, so this
    # metric varies between epochs for sequences longer than max_len.
    tbar=tqdm(val_loader)
    model.eval()
    val_preds=[]
    val_loss=0
    for idx, batch in enumerate(tbar):
        sequence=batch['sequence'].cuda()
        gt_xyz=batch['xyz'].cuda().squeeze()

        with torch.no_grad():
            pred_xyz=model(sequence).squeeze()
            loss=dRMAE(pred_xyz,pred_xyz,gt_xyz,gt_xyz)

        val_loss+=loss.item()
        val_preds.append([gt_xyz.cpu().numpy(),pred_xyz.cpu().numpy()])
    val_loss=val_loss/len(tbar)
    print(f"val loss: {val_loss}")

    # Keep the weights (and predictions) with the best validation loss so far.
    if val_loss<best_val_loss:
        best_val_loss=val_loss
        best_preds=val_preds
        torch.save(model.state_dict(),'RibonanzaNet-3D.pt')

torch.save(model.state_dict(),'RibonanzaNet-3D-final.pt')


torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.

Epoch 1 Loss: 3.2220509021484545 OOMs: 0: 100%|██████████| 542/542 [01:32<00:00,  5.86it/s]

None of the inputs have requires_grad=True. Gradients will be None

100%|██████████| 80/80 [00:04<00:00, 18.23it/s]


val loss: 3.147537238895893


Epoch 2 Loss: 2.5763137090909964 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.09it/s]
100%|██████████| 80/80 [00:04<00:00, 18.35it/s]


val loss: 2.5725802034139633


Epoch 3 Loss: 2.060368769118267 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.06it/s]
100%|██████████| 80/80 [00:04<00:00, 18.52it/s]


val loss: 2.073958469182253


Epoch 4 Loss: 1.8677558203904832 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.06it/s]
100%|██████████| 80/80 [00:04<00:00, 18.49it/s]


val loss: 1.864227869734168


Epoch 5 Loss: 1.7094296843803236 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.12it/s]
100%|██████████| 80/80 [00:04<00:00, 18.39it/s]


val loss: 1.7439701043069362


Epoch 6 Loss: 1.6031643666904352 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.15it/s]
100%|██████████| 80/80 [00:04<00:00, 18.48it/s]


val loss: 1.6472794273868203


Epoch 7 Loss: 1.5322405225786335 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.15it/s]
100%|██████████| 80/80 [00:04<00:00, 18.15it/s]


val loss: 1.5996657667681575


Epoch 8 Loss: 1.461285957386133 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.13it/s]
100%|██████████| 80/80 [00:04<00:00, 18.57it/s]


val loss: 1.539158627949655


Epoch 9 Loss: 1.4466680658369486 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.10it/s]
100%|██████████| 80/80 [00:04<00:00, 18.40it/s]


val loss: 1.472761670127511


Epoch 10 Loss: 1.419885346920288 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.09it/s]
100%|██████████| 80/80 [00:04<00:00, 18.30it/s]


val loss: 1.4416812611743808


Epoch 11 Loss: 1.350555289935362 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.07it/s]
100%|██████████| 80/80 [00:04<00:00, 18.33it/s]


val loss: 1.4564144374802708


Epoch 12 Loss: 1.3354617590943825 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.05it/s]
100%|██████████| 80/80 [00:04<00:00, 18.56it/s]


val loss: 1.417674846574664


Epoch 13 Loss: 1.3132891990931712 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.06it/s]
100%|██████████| 80/80 [00:04<00:00, 18.38it/s]


val loss: 1.4672505352646112


Epoch 14 Loss: 1.2684290409088135 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.06it/s]
100%|██████████| 80/80 [00:04<00:00, 18.36it/s]


val loss: 1.3901098381727934


Epoch 15 Loss: 1.254725495167764 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.11it/s]
100%|██████████| 80/80 [00:04<00:00, 18.38it/s]


val loss: 1.3784112486056983


Epoch 16 Loss: 1.2163728261125923 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.13it/s]
100%|██████████| 80/80 [00:04<00:00, 18.36it/s]


val loss: 1.373690372891724


Epoch 17 Loss: 1.1904584686254662 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.15it/s]
100%|██████████| 80/80 [00:04<00:00, 18.26it/s]


val loss: 1.2767485700547696


Epoch 18 Loss: 1.159861091593095 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.15it/s]
100%|██████████| 80/80 [00:04<00:00, 18.50it/s]


val loss: 1.2774228800088168


Epoch 19 Loss: 1.1656664123284421 OOMs: 0: 100%|██████████| 542/542 [01:27<00:00,  6.18it/s]
100%|██████████| 80/80 [00:04<00:00, 18.43it/s]


val loss: 1.2387274676933884


Epoch 20 Loss: 1.125452289557105 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.10it/s]
100%|██████████| 80/80 [00:04<00:00, 18.55it/s]


val loss: 1.2371392032131552


Epoch 21 Loss: 1.110833135478171 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.15it/s]
100%|██████████| 80/80 [00:04<00:00, 18.51it/s]


val loss: 1.2808918325230478


Epoch 22 Loss: 1.0902540876746618 OOMs: 0: 100%|██████████| 542/542 [01:27<00:00,  6.17it/s]
100%|██████████| 80/80 [00:04<00:00, 18.13it/s]


val loss: 1.3068129571154714


Epoch 23 Loss: 1.0874614693581839 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.11it/s]
100%|██████████| 80/80 [00:04<00:00, 18.39it/s]


val loss: 1.245852106437087


Epoch 24 Loss: 1.0592550768953408 OOMs: 0: 100%|██████████| 542/542 [01:27<00:00,  6.18it/s]
100%|██████████| 80/80 [00:04<00:00, 18.48it/s]


val loss: 1.2056237425655127


Epoch 25 Loss: 1.0624612251440977 OOMs: 0: 100%|██████████| 542/542 [01:27<00:00,  6.17it/s]
100%|██████████| 80/80 [00:04<00:00, 18.35it/s]


val loss: 1.2168593298643828


Epoch 26 Loss: 1.0520553661448488 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.13it/s]
100%|██████████| 80/80 [00:04<00:00, 18.43it/s]


val loss: 1.1653822053223848


Epoch 27 Loss: 1.0269857373954625 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.13it/s]
100%|██████████| 80/80 [00:04<00:00, 18.52it/s]


val loss: 1.1986798707395792


Epoch 28 Loss: 1.0253622567631662 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.12it/s]
100%|██████████| 80/80 [00:04<00:00, 18.40it/s]


val loss: 1.1624065461568533


Epoch 29 Loss: 1.012780145360535 OOMs: 0: 100%|██████████| 542/542 [01:27<00:00,  6.17it/s]
100%|██████████| 80/80 [00:04<00:00, 18.61it/s]


val loss: 1.250149608962238


Epoch 30 Loss: 1.0179835860909572 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.16it/s]
100%|██████████| 80/80 [00:04<00:00, 18.49it/s]


val loss: 1.1739130143076182


Epoch 31 Loss: 1.0022155696395578 OOMs: 0: 100%|██████████| 542/542 [01:27<00:00,  6.16it/s]
100%|██████████| 80/80 [00:04<00:00, 18.51it/s]


val loss: 1.1342084374278785


Epoch 32 Loss: 0.9931010201958272 OOMs: 0: 100%|██████████| 542/542 [01:27<00:00,  6.21it/s]
100%|██████████| 80/80 [00:04<00:00, 18.62it/s]


val loss: 1.174108861386776


Epoch 33 Loss: 0.9819342432976649 OOMs: 0: 100%|██████████| 542/542 [01:27<00:00,  6.18it/s]
100%|██████████| 80/80 [00:04<00:00, 18.63it/s]


val loss: 1.1533867226913572


Epoch 34 Loss: 0.9829506298932642 OOMs: 0: 100%|██████████| 542/542 [01:27<00:00,  6.18it/s]
100%|██████████| 80/80 [00:04<00:00, 18.33it/s]


val loss: 1.1285639091394841


Epoch 35 Loss: 0.9744169779424298 OOMs: 0: 100%|██████████| 542/542 [01:27<00:00,  6.18it/s]
100%|██████████| 80/80 [00:04<00:00, 18.44it/s]


val loss: 1.125099384598434


Epoch 36 Loss: 0.9969339949728379 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.14it/s]
100%|██████████| 80/80 [00:04<00:00, 18.53it/s]


val loss: 1.1667677629739046


Epoch 37 Loss: 1.0161577302866316 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.11it/s]
100%|██████████| 80/80 [00:04<00:00, 18.27it/s]


val loss: 1.1152449915185572


Epoch 38 Loss: 1.0023827074418648 OOMs: 0: 100%|██████████| 542/542 [01:27<00:00,  6.16it/s]
100%|██████████| 80/80 [00:04<00:00, 18.54it/s]


val loss: 1.1527654878795146


Epoch 39 Loss: 0.9879281921140383 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.05it/s]
100%|██████████| 80/80 [00:04<00:00, 18.51it/s]


val loss: 1.1186880856752395


Epoch 40 Loss: 0.9705443411955534 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.08it/s]
100%|██████████| 80/80 [00:04<00:00, 18.58it/s]


val loss: 1.1005441220477223


Epoch 41 Loss: 0.961172249181904 OOMs: 0: 100%|██████████| 542/542 [01:27<00:00,  6.17it/s]
100%|██████████| 80/80 [00:04<00:00, 18.00it/s]


val loss: 1.09687270373106


Epoch 42 Loss: 0.9438104247914909 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.07it/s]
100%|██████████| 80/80 [00:04<00:00, 18.57it/s]


val loss: 1.1181462418287993


Epoch 43 Loss: 0.935250915724413 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.08it/s]
100%|██████████| 80/80 [00:04<00:00, 18.50it/s]


val loss: 1.1200290346518158


Epoch 44 Loss: 0.9242337950646218 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.12it/s]
100%|██████████| 80/80 [00:04<00:00, 18.24it/s]


val loss: 1.141809639427811


Epoch 45 Loss: 0.9135305027020374 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.12it/s]
100%|██████████| 80/80 [00:04<00:00, 18.40it/s]


val loss: 1.0918872646056115


Epoch 46 Loss: 0.9110561373653887 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.12it/s]
100%|██████████| 80/80 [00:04<00:00, 18.14it/s]


val loss: 1.0845405689440668


Epoch 47 Loss: 0.9170929905838192 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.05it/s]
100%|██████████| 80/80 [00:04<00:00, 18.46it/s]


val loss: 1.1429063846357166


Epoch 48 Loss: 0.8837888142939423 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.14it/s]
100%|██████████| 80/80 [00:04<00:00, 18.61it/s]


val loss: 1.1016919179819524


Epoch 49 Loss: 0.8948839230675979 OOMs: 0: 100%|██████████| 542/542 [01:28<00:00,  6.10it/s]
100%|██████████| 80/80 [00:04<00:00, 18.42it/s]


val loss: 1.107214194536209


Epoch 50 Loss: 0.8988857603051126 OOMs: 0: 100%|██████████| 542/542 [01:29<00:00,  6.05it/s]
100%|██████████| 80/80 [00:04<00:00, 18.61it/s]


val loss: 1.1127295522019267
