In [1]:
import pandas as pd
import numpy as np
import sys
sys.path.insert(0, '/home/mm22d016/TransPolymer')
import torch
from sklearn.preprocessing import OneHotEncoder
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup, RobertaModel, RobertaConfig, RobertaTokenizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torchmetrics
from rdkit import Chem
from torchmetrics import R2Score
from PolymerSmilesTokenization import PolymerSmilesTokenizer
from copy import deepcopy

In [2]:
class Downstream_Dataset(Dataset):
    def __init__(self, dataset, tokenizer, max_token_len):
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.max_token_len = max_token_len

    def __len__(self):
        self.len = len(self.dataset)
        return self.len

    def __getitem__(self, i):
        data_row = self.dataset.iloc[i]
        sd1 = data_row[0]
        sd2 = data_row[1]
        
        anion_smile = data_row[2]
        unique_AmSM = np.unique(data['Anion Smiles'].values)
        one_hot_enocder =  OneHotEncoder(categories=[unique_AmSM], sparse=False)
        one_hot_enocder.fit([[smiles] for smiles in unique_AmSM])
        anion_smile_encoding = torch.tensor(one_hot_enocder.transform([[anion_smile]])[0].astype(int))
        
        log_li = data_row[3]
        comonomer_percentage = data_row[4]/100
        approxMW = data_row[5]
        approxTg = data_row[6]
        chain_architecture = data_row[7]
        new_feats = data_row[8:self.dataset.shape[1]-1]
    
        prop = data_row[self.dataset.shape[1]-1]

        encoding1 = self.tokenizer(
            str(sd1),
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        encoding2 = self.tokenizer(
            str(sd2),
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        
        

        return dict(
            input_ids1=encoding1["input_ids"].flatten(),
            attention_mask1=encoding1["attention_mask"].flatten(),
            input_ids2=encoding2["input_ids"].flatten(),
            attention_mask2=encoding2["attention_mask"].flatten(),
            anion_smile_encoding = anion_smile_encoding,
            log_li = torch.tensor(log_li),
            comonomer_percentage = torch.tensor(comonomer_percentage),
            approxMW = torch.tensor(approxMW),
            approxTg = torch.tensor(approxTg),
            chain_architecture = torch.tensor(chain_architecture),
            new_feats = torch.tensor(new_feats),
            prop=prop
    )
    
class DataAugmentation:
    def __init__(self, aug_indicator):
        super(DataAugmentation, self).__init__()
        self.aug_indicator = aug_indicator

    """Rotate atoms to generate more SMILES"""
    def rotate_atoms(self, li, x):
        return (li[x % len(li):] + li[:x % len(li)])

    """Generate SMILES"""
    def generate_smiles(self, smiles):
        smiles_list = []
        try:
            mol = Chem.MolFromSmiles(smiles)
        except:
            mol = None
        if mol != None:
            n_atoms = mol.GetNumAtoms()
            n_atoms_list = [nat for nat in range(n_atoms)]
            if n_atoms != 0:
                for iatoms in range(n_atoms):
                    n_atoms_list_tmp = self.rotate_atoms(n_atoms_list, iatoms)  # rotate atoms' index
                    nmol = Chem.RenumberAtoms(mol, n_atoms_list_tmp)  # renumber atoms in mol
                    try:
                        smiles = Chem.MolToSmiles(nmol,
                                                  isomericSmiles=True,  # keep isomerism
                                                  kekuleSmiles=False,  # kekulize or not
                                                  rootedAtAtom=-1,  # default
                                                  canonical=False,  # canonicalize or not
                                                  allBondsExplicit=False,  #
                                                  allHsExplicit=False)  #
                    except:
                        smiles = 'None'
                    smiles_list.append(smiles)
            else:
                smiles = 'None'
                smiles_list.append(smiles)
        else:
            try:
                smiles = Chem.MolToSmiles(mol,
                                          isomericSmiles=True,  # keep isomerism
                                          kekuleSmiles=False,  # kekulize or not
                                          rootedAtAtom=-1,  # default
                                          canonical=False,  # canonicalize or not
                                          allBondsExplicit=False,  #
                                          allHsExplicit=False)  #
            except:
                smiles = 'None'
            smiles_list.append(smiles)
        smiles_array = pd.DataFrame(smiles_list).drop_duplicates().values
        # """
        if self.aug_indicator is not None:
            smiles_aug = smiles_array[1:, :]
            np.random.shuffle(smiles_aug)
            smiles_array = np.vstack((smiles_array[0, :], smiles_aug[:self.aug_indicator-1, :]))
        return smiles_array

    """SMILES Augmentation"""
    def smiles_augmentation(self, df):
        column_list = df.columns
        data_aug = np.zeros((1, df.shape[1]))
        for i in range(df.shape[0]):
            smiles = df.iloc[i, 0]
            prop = df.iloc[i, 1:]
            smiles_array = self.generate_smiles(smiles)
            if 'None' not in smiles_array:
                prop = np.tile(prop, (len(smiles_array), 1))
                data_new = np.hstack((smiles_array, prop))
                data_aug = np.vstack((data_aug, data_new))
        df_aug = pd.DataFrame(data_aug[1:, :], columns=column_list)
        return df_aug

    """Used for copolymers with two repeating units"""
    def smiles_augmentation_2(self, df):
        df_columns = df.columns
        column_list = df.columns.tolist()
        column_list_temp = deepcopy(column_list)
        column_list_temp[0] = column_list[1]
        column_list_temp[1] = column_list[0]
        df = df[column_list_temp]
        data_aug = np.zeros((1, df.shape[1]))
        for i in range(df.shape[0]):
            if df.loc[i, "Comonomer percentage"] == 100.0:
                data_new = df.values[i, :].reshape(1, -1)
                data_aug = np.vstack((data_aug, data_new))
            else:
                smiles = df.iloc[i, 0]
                prop = df.iloc[i, 1:]
                smiles_array = self.generate_smiles(smiles)
                if 'None' not in smiles_array:
                    prop = np.tile(prop, (len(smiles_array), 1))
                    data_new = np.hstack((smiles_array, prop))
                    data_aug = np.vstack((data_aug, data_new))
        data_aug_copy = deepcopy(data_aug)
        data_aug_copy[:, 0] = data_aug[:, 1]
        data_aug_copy[:, 1] = data_aug[:, 0]
        df_aug = pd.DataFrame(data_aug_copy[1:, :], columns=df_columns)
        return df_aug
    
    def combine_smiles(self, df):
        for i in range(df.shape[0]):
            if df.loc[i, "Comonomer percentage"] != 100.0:
                df.loc[i, "SMILES descriptor 1"] = df.loc[i, "SMILES descriptor 1"] + '.' + df.loc[
                    i, "SMILES descriptor 2"]
        df = df.drop(columns=['SMILES descriptor 2'])
        return df

In [3]:
class DownstreamRegression(nn.Module):
    def __init__(self, drop_rate=0.1):
        super(DownstreamRegression, self).__init__()
        self.PretrainedModel = deepcopy(PretrainedModel)
        self.PretrainedModel.resize_token_embeddings(len(tokenizer))
        
        self.smile_autoencoder = nn.Sequential(
            nn.Dropout(drop_rate),
            nn.Linear(self.PretrainedModel.config.hidden_size, self.PretrainedModel.config.hidden_size),
            nn.SiLU(),
            nn.Linear(self.PretrainedModel.config.hidden_size, 128)
        )
        
        self.Regressor = nn.Sequential(
            nn.Dropout(drop_rate),
            nn.Linear(211, 211),
            nn.SiLU(),
            nn.Linear(211, 1)
        )

        
        
    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2, anion_smile_encoding, log_li, comonomer_percentage, approxMW, approxTg, chain_architecture, new_feats):
        outputs1 = self.PretrainedModel(input_ids=input_ids1, attention_mask=attention_mask1)
        logits1 = outputs1.last_hidden_state[:, 0, :]
        outputs2 = self.PretrainedModel(input_ids=input_ids2, attention_mask=attention_mask2)
        logits2 = outputs2.last_hidden_state[:, 0, :]
        logits = ((logits1*comonomer_percentage.unsqueeze(1))+((1-comonomer_percentage.unsqueeze(1))*logits2))
        print("logits",logits,logits.shape)
        smile_embeddings = self.smile_autoencoder(logits)
        print("smile_embeddings",smile_embeddings,smile_embeddings.shape)
        input_to_regressor = torch.cat([smile_embeddings, anion_smile_encoding, log_li.unsqueeze(1), approxMW.unsqueeze(1), approxTg.unsqueeze(1), chain_architecture.unsqueeze(1), new_feats], dim=1).clone().detach()
        print("input_to_regressor",input_to_regressor,input_to_regressor.shape)
        output = self.Regressor(input_to_regressor)
        
        return output


def train(model, loss_fn, train_dataloader, device):

    #model.train()

    for step, batch in enumerate(train_dataloader):
        input_ids1 = batch["input_ids1"].to(device)

        attention_mask1 = batch["attention_mask1"].to(device)
        #print("attention_mask1",attention_mask1,attention_mask1.shape)
        input_ids2 = batch["input_ids2"].to(device)
        #print("input_ids2",input_ids2,input_ids2.shape)
        attention_mask2 = batch["attention_mask2"].to(device)
        #print("attention_mask2",attention_mask2,attention_mask2.shape)
        anion_smile_encoding = batch["anion_smile_encoding"].to(device)
        #print("anion_smile_encoding",anion_smile_encoding,anion_smile_encoding.shape)
        log_li = batch["log_li"].to(device)
        #print("log_li",log_li,log_li.shape)
        comonomer_percentage = batch["comonomer_percentage"].to(device)
        #print("comonomer_percentage",comonomer_percentage,comonomer_percentage.shape)
        approxMW = batch["approxMW"].to(device)
        #print("approxMW",approxMW,approxMW.shape)
        approxTg = batch["approxTg"].to(device)
        #print("approxTg",approxTg,approxTg.shape)
        chain_architecture = batch["chain_architecture"].to(device)
        #print("chain_architecture",chain_architecture,chain_architecture.shape)
        new_feats = batch["new_feats"].to(device)
        #print("new_feats",new_feats,new_feats.shape)
        prop = batch["prop"].to(device).float()
        #print("prop",prop,prop.shape)
        outputs = model(input_ids1, attention_mask1, input_ids2, attention_mask2, anion_smile_encoding, log_li, comonomer_percentage, approxMW, approxTg, chain_architecture, new_feats).float()
        print("End of Batch ------------------------------------------------------------------------------------")

        #loss = loss_fn(outputs.squeeze(), prop.squeeze())
        #loss.backward()
        #optimizer.step()
        #scheduler.step()

    return None


In [4]:
def test(model, loss_fn, train_dataloader, test_dataloader, scaler_target, device):

    r2score = R2Score()
    train_loss = 0
    test_loss = 0
    # count = 0
    model.eval()
    with torch.no_grad():
        train_pred, train_true, test_pred, test_true = torch.tensor([]), torch.tensor([]), torch.tensor(
            []), torch.tensor([])

        print("Train_dataloader")
        for step, batch in enumerate(train_dataloader):
            input_ids1 = batch["input_ids1"].to(device)
            attention_mask1 = batch["attention_mask1"].to(device)
            input_ids2 = batch["input_ids2"].to(device)
            attention_mask2 = batch["attention_mask2"].to(device)
            anion_smile_encoding = batch["anion_smile_encoding"].to(device)
            log_li = batch["log_li"].to(device)
            comonomer_percentage = batch["comonomer_percentage"].to(device)
            approxMW = batch["approxMW"].to(device)
            approxTg = batch["approxTg"].to(device)
            chain_architecture = batch["chain_architecture"].to(device)
            new_feats = batch["new_feats"].to_device()
            prop = batch["prop"].to(device).float()
            outputs = model(input_ids1, attention_mask1, input_ids2, attention_mask2, anion_smile_encoding, log_li, comonomer_percentage, approxMW, approxTg, chain_architecture, new_feats).float()
            outputs = torch.from_numpy(scaler_target.inverse_transform(outputs.cpu().reshape(-1, 1)))
            prop = torch.from_numpy(scaler_target.inverse_transform(prop.cpu().reshape(-1, 1)))
            loss = loss_fn(outputs.squeeze(), prop.squeeze())
            train_loss += loss.item() * len(prop)
            train_pred = torch.cat([train_pred.to(device), outputs.to(device)])
            train_true = torch.cat([train_true.to(device), prop.to(device)])
            print("End of Batch ------------------------------------------")

        train_loss = train_loss / len(train_pred.flatten())
        r2_train = r2score(train_pred.flatten().to("cpu"), train_true.flatten().to("cpu")).item()

        print("Test_dataloader")
        for step, batch in enumerate(test_dataloader):
            input_ids1 = batch["input_ids1"].to(device)
            attention_mask1 = batch["attention_mask1"].to(device)
            input_ids2 = batch["input_ids2"].to(device)
            attention_mask2 = batch["attention_mask2"].to(device)
            anion_smile_encoding = batch["anion_smile_encoding"].to(device)
            log_li = batch["log_li"].to(device)
            comonomer_percentage = batch["comonomer_percentage"].to(device)
            approxMW = batch["approxMW"].to(device)
            approxTg = batch["approxTg"].to(device)
            chain_architecture = batch["chain_architecture"].to(device)
            new_feats = batch["new_feats"].to_device()
            prop = batch["prop"].to(device).float()
            outputs = model(input_ids1, attention_mask1, input_ids2, attention_mask2, anion_smile_encoding, log_li, comonomer_percentage, approxMW, approxTg, chain_architecture, new_feats).float()
            outputs = torch.from_numpy(scaler_target.inverse_transform(outputs.cpu().reshape(-1, 1)))
            prop = torch.from_numpy(scaler_target.inverse_transform(prop.cpu().reshape(-1, 1)))
            loss = loss_fn(outputs.squeeze(), prop.squeeze())
            test_loss += loss.item() * len(prop)
            test_pred = torch.cat([test_pred.to(device), outputs.to(device)])
            test_true = torch.cat([test_true.to(device), prop.to(device)])
            print("End of Batch ------------------------------------------")
            
        test_loss = test_loss / len(test_pred.flatten())
        r2_test = r2score(test_pred.flatten().to("cpu"), test_true.flatten().to("cpu")).item()
    
    return None

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_sup = pd.read_csv('../data/vocab/vocab_sup_PE_II.csv', header=None).values.flatten().tolist()
PretrainedModel = RobertaModel.from_pretrained('../ckpt/pretrain.pt')
tokenizer = PolymerSmilesTokenizer.from_pretrained("roberta-base", max_len=411)
tokenizer.add_tokens(vocab_sup)
PretrainedModel.config.hidden_dropout_prob = 0.1
PretrainedModel.config.attention_probs_dropout_prob = 0.1
max_token_len = 411
model = DownstreamRegression(drop_rate=0.1).to(device)
model = model.double()
loss_fn = nn.MSELoss()
torch.cuda.empty_cache()

Some weights of the model checkpoint at ../ckpt/pretrain.pt were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at ../ckpt/pretrain.pt and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The tokenizer clas

In [6]:
data = pd.read_csv('../data/new_feat_Li.csv')
data['chain architecture'] = data['chain architecture'].replace({'S_1': 0, 'S_2': 1})
train_data, test_data = train_test_split(data, test_size=0.2, random_state=1)
train_data.reset_index(drop=True, inplace=True)
test_data.reset_index(drop=True, inplace=True)
train_data

Unnamed: 0,SMILES descriptor 1,SMILES descriptor 2,Anion Smiles,log Li:functional group,Comonomer percentage,approxMW(kDa),approxTg,chain architecture,Comonomer 1 CIC1,Comonomer 1 JGI8,...,Comonomer 2 RDF100m,Comonomer 2 SMR_VSA7,Comonomer 2 SlogP_VSA10,Comonomer 2 SlogP_VSA12,Comonomer 2 SlogP_VSA6,Comonomer 2 TDB10s,Comonomer 2 VSA_EState10,Comonomer 2 fr_aryl_methyl,drying vacuum,logCond60
0,*CC(C*)C#N,*CC(C*)C#N,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-0.778151,100.0,170.000000,80.000000,0,1.804290,0.000000,...,0.000000e+00,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0,2,-6.690204
1,CC1=C(C)C(COCP(=S)(CO*)CC2=CC=CC=C2)=C(C)C(C)=...,CC1=C(C)C(COCP(=S)(CO*)CC2=CC=CC=C2)=C(C)C(C)=...,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-0.255273,100.0,9.600000,171.000000,0,3.010579,0.012467,...,3.508434e+00,69.778911,0.000000,6.037803,36.398202,18.374741,5.976773,2,2,-4.838063
2,*CNC*,*CNC*,C(F)(F)(F)S(=O)(=O)[O-],-1.176091,100.0,122664.000000,-17.328918,0,1.750978,0.000000,...,0.000000e+00,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0,2,-5.541343
3,COCCOCC(C*)O*,*OC(*)=O,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.477121,50.0,17.000000,-55.000000,1,2.959037,0.000000,...,0.000000e+00,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0,2,-4.528742
4,*CCO*,*CCO*,F[P-](F)(F)(F)(F)F,-1.204120,100.0,4000.000000,-44.000000,0,1.945531,0.000000,...,0.000000e+00,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0,0,-6.656500
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
211,*CCO*,*CCO*,C(F)(F)(F)S(=O)(=O)[O-],-1.079181,100.0,13.207547,-42.000000,0,1.945531,0.000000,...,0.000000e+00,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0,2,-4.394700
212,*CCOCCOCCOCCOCCOCCOCCOCCOCCOCCOCCOC*,*C1OC(=O)OC1*,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.204120,50.0,8.100000,-35.600000,1,4.858178,0.000950,...,0.000000e+00,0.000000,4.794537,0.000000,0.000000,0.000000,0.000000,0,0,-3.629528
213,*CCCCCC(=O)O*,*CCCCCC(=O)O*,[O-]Cl(=O)(=O)=O,-1.469449,100.0,65.000000,-55.000000,0,2.250977,0.000000,...,6.220000e-93,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0,2,-4.607459
214,CCOCOCCN(*)CC(*)=O,CCOCOCCN(*)CC(*)=O,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.070581,100.0,3.242800,9.900000,1,2.309918,0.008230,...,1.710878e-01,0.000000,0.000000,0.000000,0.000000,22.410993,0.000000,0,2,-4.822385


In [7]:
test_data

Unnamed: 0,SMILES descriptor 1,SMILES descriptor 2,Anion Smiles,log Li:functional group,Comonomer percentage,approxMW(kDa),approxTg,chain architecture,Comonomer 1 CIC1,Comonomer 1 JGI8,...,Comonomer 2 RDF100m,Comonomer 2 SMR_VSA7,Comonomer 2 SlogP_VSA10,Comonomer 2 SlogP_VSA12,Comonomer 2 SlogP_VSA6,Comonomer 2 TDB10s,Comonomer 2 VSA_EState10,Comonomer 2 fr_aryl_methyl,drying vacuum,logCond60
0,*CCO*,COCCOCCOCC(C*)O*,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.69897,91.0,277.777778,-55.45,1,1.945531,0.0,...,0.07948876,0.0,0.0,0.0,0.0,19.149043,0.0,0,3,-3.460793
1,CC(CO*)OC(=O)CCCC(*)=O,CC(CO*)OC(=O)CCCC(*)=O,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.09691,100.0,8.8,-16.0,0,2.357806,0.00823,...,0.3233966,0.0,0.0,0.0,0.0,15.851587,0.0,0,2,-4.661223
2,*CCO*,*CC(COCC=C)O*,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.422905,76.7,7.9,-69.0,0,1.945531,0.0,...,1.25e-05,12.654956,0.0,0.0,12.654956,0.0,0.0,0,3,-3.652059
3,*CC(C*)C#N,*CC(C*)C#N,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.079181,100.0,170.0,80.0,0,1.80429,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,2,-6.735954
4,*CCOC(=O)O*,*CCO*,[B-](F)(F)(F)F,-0.135672,95.2,37.0,-69.0,0,1.625815,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,2,-4.34
5,*CCOC(=O)O*,*CCO*,C(F)(F)(F)S(=O)(=O)[O-],-0.083089,95.2,37.0,-7.0,0,1.625815,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,2,-6.7
6,*OCC(COCC=C)OC(=O)COCC(*)=O,*OCC(COCC=C)OC(=O)COCC(*)=O,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.154902,100.0,8.9,4.0,1,2.112984,0.006173,...,0.03650635,12.654956,0.0,0.0,12.654956,23.910774,0.0,0,2,-4.948854
7,C[Si](*)(CCSCCCC(=O)NCCCN1C=CN=C1)O*,C[Si](*)(CCSCCCC(=O)NCCCN1C=CN=C1)O*,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.823909,100.0,110943.6,-6.0,1,1.908336,0.009575,...,0.0,18.721007,0.0,11.761885,18.721007,0.0,1.818517,0,3,-6.446174
8,*CCO*,CC(C*)O*,[O-]Cl(=O)(=O)=O,-0.958607,47.0,5.0,-7.0,0,1.945531,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,2,-5.721043
9,C[Si](*)(CCSCCCC(=O)NCCCN1C=CN=C1)O*,C[Si](*)(CCSCCCCCCCC1=CC=CC=C1)O*,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.158015,72.0,67760.0,-13.0,1,1.908336,0.009575,...,0.0,35.895287,0.0,11.761885,30.331835,0.0,2.007703,1,3,-5.97141


In [8]:
DataAug = DataAugmentation(None)
train_data = DataAug.smiles_augmentation(train_data)
train_data

In [9]:
train_data = DataAug.smiles_augmentation_2(train_data)
train_data

In [10]:
scaler_feat = StandardScaler()
train_data[['approxMW(kDa)','approxTg','Comonomer 1 CIC1', 'Comonomer 1 JGI8', 'Comonomer 1 MQNs_16','Comonomer 1 NumAromaticRings', 'Comonomer 1 RDF120u', 'Comonomer 1 RDF130m', 'Comonomer 1 RDF65u', 'Comonomer 1 RDF95m', 'Comonomer 1 RDF95u', 'Comonomer 1 SMR_VSA10', 'Comonomer 1 SMR_VSA7', 'Comonomer 1 SlogP_VSA12', 'Comonomer 1 SlogP_VSA6', 'Comonomer 1 fr_aryl_methyl', 'Comonomer 1 qed', 'Comonomer 2 AVP-0', 'Comonomer 2 GATS3m', 'Comonomer 2 GATS5v', 'Comonomer 2 MATS3s', 'Comonomer 2 MATS4c', 'Comonomer 2 MQNs_16', 'Comonomer 2 MQNs_5', 'Comonomer 2 NumAromaticCarbocycles', 'Comonomer 2 NumAromaticRings', 'Comonomer 2 RDF100m', 'Comonomer 2 SMR_VSA7', 'Comonomer 2 SlogP_VSA10', 'Comonomer 2 SlogP_VSA12', 'Comonomer 2 SlogP_VSA6', 'Comonomer 2 TDB10s', 'Comonomer 2 VSA_EState10', 'Comonomer 2 fr_aryl_methyl','drying vacuum']] = scaler_feat.fit_transform(train_data[['approxMW(kDa)','approxTg','Comonomer 1 CIC1', 'Comonomer 1 JGI8', 'Comonomer 1 MQNs_16','Comonomer 1 NumAromaticRings', 'Comonomer 1 RDF120u', 'Comonomer 1 RDF130m', 'Comonomer 1 RDF65u', 'Comonomer 1 RDF95m', 'Comonomer 1 RDF95u', 'Comonomer 1 SMR_VSA10', 'Comonomer 1 SMR_VSA7', 'Comonomer 1 SlogP_VSA12', 'Comonomer 1 SlogP_VSA6', 'Comonomer 1 fr_aryl_methyl', 'Comonomer 1 qed', 'Comonomer 2 AVP-0', 'Comonomer 2 GATS3m', 'Comonomer 2 GATS5v', 'Comonomer 2 MATS3s', 'Comonomer 2 MATS4c', 'Comonomer 2 MQNs_16', 'Comonomer 2 MQNs_5', 'Comonomer 2 NumAromaticCarbocycles', 'Comonomer 2 NumAromaticRings', 'Comonomer 2 RDF100m', 'Comonomer 2 SMR_VSA7', 'Comonomer 2 SlogP_VSA10', 'Comonomer 2 SlogP_VSA12', 'Comonomer 2 SlogP_VSA6', 'Comonomer 2 TDB10s', 'Comonomer 2 VSA_EState10', 'Comonomer 2 fr_aryl_methyl','drying vacuum']])
test_data[['approxMW(kDa)','approxTg','Comonomer 1 CIC1', 'Comonomer 1 JGI8', 'Comonomer 1 MQNs_16','Comonomer 1 NumAromaticRings', 'Comonomer 1 RDF120u', 'Comonomer 1 RDF130m', 'Comonomer 1 RDF65u', 'Comonomer 1 RDF95m', 'Comonomer 1 RDF95u', 'Comonomer 1 SMR_VSA10', 'Comonomer 1 SMR_VSA7', 'Comonomer 1 SlogP_VSA12', 'Comonomer 1 SlogP_VSA6', 'Comonomer 1 fr_aryl_methyl', 'Comonomer 1 qed', 'Comonomer 2 AVP-0', 'Comonomer 2 GATS3m', 'Comonomer 2 GATS5v', 'Comonomer 2 MATS3s', 'Comonomer 2 MATS4c', 'Comonomer 2 MQNs_16', 'Comonomer 2 MQNs_5', 'Comonomer 2 NumAromaticCarbocycles', 'Comonomer 2 NumAromaticRings', 'Comonomer 2 RDF100m', 'Comonomer 2 SMR_VSA7', 'Comonomer 2 SlogP_VSA10', 'Comonomer 2 SlogP_VSA12', 'Comonomer 2 SlogP_VSA6', 'Comonomer 2 TDB10s', 'Comonomer 2 VSA_EState10', 'Comonomer 2 fr_aryl_methyl','drying vacuum']] = scaler_feat.transform(test_data[['approxMW(kDa)','approxTg','Comonomer 1 CIC1', 'Comonomer 1 JGI8', 'Comonomer 1 MQNs_16','Comonomer 1 NumAromaticRings', 'Comonomer 1 RDF120u', 'Comonomer 1 RDF130m', 'Comonomer 1 RDF65u', 'Comonomer 1 RDF95m', 'Comonomer 1 RDF95u', 'Comonomer 1 SMR_VSA10', 'Comonomer 1 SMR_VSA7', 'Comonomer 1 SlogP_VSA12', 'Comonomer 1 SlogP_VSA6', 'Comonomer 1 fr_aryl_methyl', 'Comonomer 1 qed', 'Comonomer 2 AVP-0', 'Comonomer 2 GATS3m', 'Comonomer 2 GATS5v', 'Comonomer 2 MATS3s', 'Comonomer 2 MATS4c', 'Comonomer 2 MQNs_16', 'Comonomer 2 MQNs_5', 'Comonomer 2 NumAromaticCarbocycles', 'Comonomer 2 NumAromaticRings', 'Comonomer 2 RDF100m', 'Comonomer 2 SMR_VSA7', 'Comonomer 2 SlogP_VSA10', 'Comonomer 2 SlogP_VSA12', 'Comonomer 2 SlogP_VSA6', 'Comonomer 2 TDB10s', 'Comonomer 2 VSA_EState10', 'Comonomer 2 fr_aryl_methyl','drying vacuum']])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_data[['approxMW(kDa)','approxTg','Comonomer 1 CIC1', 'Comonomer 1 JGI8', 'Comonomer 1 MQNs_16','Comonomer 1 NumAromaticRings', 'Comonomer 1 RDF120u', 'Comonomer 1 RDF130m', 'Comonomer 1 RDF65u', 'Comonomer 1 RDF95m', 'Comonomer 1 RDF95u', 'Comonomer 1 SMR_VSA10', 'Comonomer 1 SMR_VSA7', 'Comonomer 1 SlogP_VSA12', 'Comonomer 1 SlogP_VSA6', 'Comonomer 1 fr_aryl_methyl', 'Comonomer 1 qed', 'Comonomer 2 AVP-0', 'Comonomer 2 GATS3m', 'Comonomer 2 GATS5v', 'Comonomer 2 MATS3s', 'Comonomer 2 MATS4c', 'Comonomer 2 MQNs_16', 'Comonomer 2 MQNs_5', 'Comonomer 2 NumAromaticCarbocycles', 'Comonomer 2 NumAromaticRings', 'Comonomer 2 RDF100m', 'Comonomer 2 SMR_VSA7', 'Comonomer 2 SlogP_VSA10', 'Comonomer 2 SlogP_VSA12', 'Comonomer 2 S

In [11]:
scaler_target = StandardScaler()
train_data.iloc[:, train_data.shape[1]-1] = scaler_target.fit_transform(train_data.iloc[:, train_data.shape[1]-1].values.reshape(-1, 1))
test_data.iloc[:, test_data.shape[1]-1] = scaler_target.transform(test_data.iloc[:, test_data.shape[1]-1].values.reshape(-1, 1))

In [13]:
test_data

Unnamed: 0,SMILES descriptor 1,SMILES descriptor 2,Anion Smiles,log Li:functional group,Comonomer percentage,approxMW(kDa),approxTg,chain architecture,Comonomer 1 CIC1,Comonomer 1 JGI8,...,Comonomer 2 RDF100m,Comonomer 2 SMR_VSA7,Comonomer 2 SlogP_VSA10,Comonomer 2 SlogP_VSA12,Comonomer 2 SlogP_VSA6,Comonomer 2 TDB10s,Comonomer 2 VSA_EState10,Comonomer 2 fr_aryl_methyl,drying vacuum,logCond60
0,*CCO*,COCCOCCOCC(C*)O*,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.69897,91.0,-0.387984,-0.805523,1,-0.556557,-0.726758,...,-0.271044,-0.33034,-0.525625,-0.372231,-0.388859,0.823197,-0.324604,-0.223933,1.072718,1.266018
1,CC(CO*)OC(=O)CCCC(*)=O,CC(CO*)OC(=O)CCCC(*)=O,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.09691,100.0,-0.395315,-0.025342,0,-0.085529,1.33431,...,0.032122,-0.33034,-0.525625,-0.372231,-0.388859,0.589743,-0.324604,-0.223933,0.014695,0.27922
2,*CCO*,*CC(COCC=C)O*,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.422905,76.7,-0.395339,-1.073494,0,-0.556557,-0.726758,...,-0.36983,0.55607,-0.525625,-0.372231,1.057407,-0.532523,-0.324604,-0.223933,1.072718,1.10879
3,*CC(C*)C#N,*CC(C*)C#N,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.079181,100.0,-0.390921,1.873198,0,-0.717925,-0.726758,...,-0.369845,-0.33034,-0.525625,-0.372231,-0.388859,-0.532523,-0.324604,-0.223933,0.014695,-1.426285
4,*CCOC(=O)O*,*CCO*,[B-](F)(F)(F)F,-0.135672,95.2,-0.394546,-1.073494,0,-0.921835,-0.726758,...,-0.369845,-0.33034,-0.525625,-0.372231,-0.388859,-0.532523,-0.324604,-0.223933,0.014695,0.543277
5,*CCOC(=O)O*,*CCO*,C(F)(F)(F)S(=O)(=O)[O-],-0.083089,95.2,-0.394546,0.152646,0,-0.921835,-0.726758,...,-0.369845,-0.33034,-0.525625,-0.372231,-0.388859,-0.532523,-0.324604,-0.223933,0.014695,-1.39673
6,*OCC(COCC=C)OC(=O)COCC(*)=O,*OCC(COCC=C)OC(=O)COCC(*)=O,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.154902,100.0,-0.395312,0.370187,1,-0.365241,0.819043,...,-0.32447,0.55607,-0.525625,-0.372231,1.057407,1.160319,-0.324604,-0.223933,0.014695,0.042777
7,C[Si](*)(CCSCCCC(=O)NCCCN1C=CN=C1)O*,C[Si](*)(CCSCCCC(=O)NCCCN1C=CN=C1)O*,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.823909,100.0,2.627992,0.172423,1,-0.599053,1.67099,...,-0.369845,0.980964,-0.525625,2.963018,1.750663,-0.532523,1.15981,-0.223933,1.072718,-1.188075
8,*CCO*,CC(C*)O*,[O-]Cl(=O)(=O)=O,-0.958607,47.0,-0.395418,0.152646,0,-0.556557,-0.726758,...,-0.369845,-0.33034,-0.525625,-0.372231,-0.388859,-0.532523,-0.324604,-0.223933,0.014695,-0.591991
9,C[Si](*)(CCSCCCC(=O)NCCCN1C=CN=C1)O*,C[Si](*)(CCSCCCCCCCC1=CC=CC=C1)O*,C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,-1.158015,72.0,1.451109,0.033987,1,-0.599053,1.67099,...,-0.369845,2.183927,-0.525625,2.963018,3.077601,-0.532523,1.314237,2.321829,1.072718,-0.797802


In [14]:
train_dataset = Downstream_Dataset(train_data, tokenizer, 411)
test_dataset = Downstream_Dataset(test_data, tokenizer, 411)

In [15]:
train_dataloader = DataLoader(train_dataset, 2, shuffle=False, num_workers=8)
test_dataloader = DataLoader(test_dataset, 2, shuffle=False, num_workers=8)

In [16]:
train(model, loss_fn, train_dataloader, device)

logits tensor([[-0.5995, -0.5397,  0.0674,  ..., -0.2653, -0.1414,  1.5422],
        [ 0.6032, -2.1316, -0.1725,  ..., -1.5678, -0.5253, -0.3527]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-1.7507e-02,  3.0780e-01, -3.6170e-01, -1.5713e-01,  2.2114e-01,
          3.3442e-03, -1.2810e-01,  2.0819e-02, -2.0039e-02,  1.0930e-01,
         -2.2876e-01, -5.5523e-02, -1.2781e-01,  2.1920e-01, -1.8712e-01,
         -4.1816e-02, -4.1408e-01,  1.9307e-01,  1.7334e-01,  2.9918e-01,
          6.9719e-02,  3.4996e-01,  1.2484e-01, -3.5378e-01,  1.5688e-02,
         -1.4233e-01, -8.6232e-02,  1.9506e-01,  2.1952e-01,  2.3682e-03,
         -1.3299e-01, -2.2821e-01, -1.1878e-01, -1.0824e-01,  2.4553e-01,
          1.1942e-01, -1.8458e-02, -9.9897e-02,  1.3426e-01, -1.3144e-01,
          1.4590e-01,  3.0798e-01,  5.9213e-01, -2.2333e-02, -5.8445e-02,
         -9.7453e-02,  2.2535e-01,  3.7430e-01, -1.3734e-01,  3.1384e-01,
     

smile_embeddings tensor([[-0.0403,  0.2243, -0.2317, -0.0624,  0.2025, -0.0994, -0.0743,  0.1363,
          0.0679,  0.1109, -0.3074,  0.0077,  0.0303,  0.0335, -0.2377, -0.1017,
         -0.3056,  0.1853,  0.0501,  0.4205,  0.1553,  0.2113,  0.1102, -0.4130,
          0.1476, -0.1027, -0.0901,  0.2963,  0.0419,  0.0747,  0.0167, -0.1760,
         -0.1342, -0.0540,  0.3841,  0.0825, -0.1373, -0.0359,  0.1501, -0.1457,
          0.1682,  0.1647,  0.5669,  0.1133,  0.0549, -0.1079,  0.3650,  0.3363,
         -0.1908,  0.3146,  0.4307,  0.1116, -0.0436, -0.1133,  0.2214,  0.1240,
         -0.1689, -0.0899, -0.2309, -0.0463,  0.0827,  0.1406,  0.1869,  0.0743,
          0.1339, -0.0710,  0.3597, -0.1587, -0.0118, -0.1146, -0.0638, -0.0262,
         -0.4249,  0.3877, -0.1336,  0.2301,  0.1767,  0.2679, -0.3269,  0.0444,
         -0.2698,  0.0396,  0.0526,  0.3878,  0.0350, -0.2360, -0.1761,  0.0853,
          0.1199,  0.1017,  0.3797,  0.3076, -0.1910,  0.2944,  0.2518, -0.1197,
          0

logits tensor([[ 0.6099, -1.1832, -0.4592,  ..., -1.0539, -0.0378,  0.8072],
        [ 0.1294, -1.8128, -0.1654,  ..., -1.4135,  0.1382,  0.7863]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.1566,  0.0357, -0.2511, -0.2987,  0.1182,  0.1437, -0.2654, -0.1449,
         -0.1384,  0.0995,  0.0253,  0.1021, -0.2399,  0.0627,  0.0063, -0.1191,
         -0.2659,  0.0425, -0.0231,  0.2825, -0.1453,  0.0248,  0.1832, -0.2609,
          0.1144, -0.0761, -0.0394,  0.0715,  0.0638, -0.1975,  0.1701, -0.1350,
         -0.1756,  0.2026,  0.0697,  0.0861,  0.2725, -0.0073, -0.0028, -0.0750,
          0.1087, -0.0200,  0.0416, -0.0632,  0.0061, -0.1194, -0.0631,  0.1175,
         -0.0535, -0.0375,  0.1237, -0.1197,  0.0541, -0.2592,  0.2299,  0.1310,
         -0.0786, -0.3012, -0.0227,  0.1082,  0.2599, -0.0070, -0.0607, -0.1256,
          0.2430,  0.0901,  0.1593, -0.0701, -0.2133, -0.2954, -0.2642, -0.2045,
          0.0169

logits tensor([[ 0.8386, -0.5247,  0.1221,  ..., -0.0886, -0.1060,  2.2836],
        [-0.4638, -0.8124,  0.1468,  ..., -0.4924,  0.4880,  1.8449]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.0492, -0.0556, -0.0903, -0.0189,  0.0486,  0.0387, -0.1742, -0.0595,
          0.0290,  0.1070, -0.0327, -0.0252, -0.1094,  0.1928, -0.3529, -0.2534,
         -0.3234,  0.1900, -0.2032,  0.3310, -0.0056,  0.2842,  0.3372, -0.4240,
         -0.0794, -0.0254, -0.0709,  0.0437,  0.0401,  0.0378,  0.1311, -0.2328,
          0.2076,  0.2239,  0.2800,  0.0378, -0.1267, -0.1358,  0.2826, -0.0403,
          0.3549,  0.1863,  0.3543, -0.0153, -0.2933, -0.3421,  0.1342,  0.5220,
          0.3143,  0.0397,  0.1550,  0.0464, -0.0455, -0.1974,  0.0428,  0.1946,
         -0.3277, -0.0926, -0.2478,  0.2845, -0.0130,  0.2969,  0.2560, -0.0593,
          0.2506,  0.1290,  0.2943, -0.1713, -0.0746, -0.3298, -0.5100, -0.0694,
         -0.3147

logits tensor([[-0.2201, -1.2003, -0.1129,  ..., -0.4003,  0.5902,  1.1035],
        [ 0.1387, -1.5587, -0.0978,  ..., -0.2143, -0.2845,  0.7609]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-2.4673e-02,  3.7002e-01, -4.1227e-01,  1.8881e-02,  2.4161e-01,
         -9.0571e-02,  4.9651e-02, -4.6838e-02, -8.1031e-02,  1.0413e-01,
         -2.1596e-01,  5.8511e-02, -3.1360e-02,  1.1099e-01, -2.3279e-01,
          1.9666e-02, -2.6774e-01,  2.5359e-01,  1.9786e-01,  5.1087e-01,
          1.6142e-01,  3.9306e-01,  1.2489e-01, -3.7283e-01,  1.2591e-01,
          2.4763e-02, -3.7741e-02,  1.8716e-01,  1.1234e-01,  1.0159e-01,
         -7.5299e-02, -1.4351e-01, -4.5999e-02, -1.1993e-01,  2.7004e-01,
          9.9644e-02, -1.6272e-01, -1.5030e-01,  5.5808e-02, -3.5710e-02,
          8.8550e-02,  2.7366e-01,  5.1305e-01,  1.4478e-01, -8.0500e-02,
         -1.7253e-01,  1.3414e-01,  2.9143e-01, -2.0097e-01,  2.6030e-01,
     

input_to_regressor tensor([[-1.3136e-01,  4.4812e-01, -3.4339e-01,  9.2829e-02,  2.8123e-01,
         -1.6990e-01,  3.4953e-02,  2.4425e-03, -1.1754e-01,  2.2298e-01,
         -2.0866e-01,  6.7828e-03, -6.7433e-02,  1.7130e-01, -3.4872e-01,
         -6.3040e-02, -3.1733e-01,  9.5329e-02,  2.5488e-01,  4.9422e-01,
          2.0005e-01,  3.4296e-01,  1.9228e-01, -3.3405e-01,  3.7682e-02,
         -1.0071e-01, -5.7783e-02,  2.5503e-01,  1.0776e-01,  3.2352e-02,
         -1.9306e-01, -1.1485e-01, -1.0540e-01, -6.1962e-02,  2.1109e-01,
          8.3152e-02, -1.2932e-01,  4.5825e-04, -5.0640e-02, -1.0429e-02,
          1.9086e-01,  1.8758e-01,  5.1408e-01,  1.3869e-01, -1.8196e-01,
         -1.7506e-01,  2.1566e-01,  3.4149e-01, -1.9001e-01,  2.3395e-01,
          3.5984e-01,  8.1950e-02,  2.3532e-01, -2.2235e-02,  3.9660e-01,
         -6.0879e-02,  6.5293e-02, -2.1118e-01, -7.8444e-02, -1.2586e-01,
         -8.7110e-02, -4.8175e-02,  5.6685e-02,  1.1894e-01,  2.3239e-01,
         -1.8819e-0

logits tensor([[-0.2966, -0.8822, -0.8697,  ...,  0.3624,  0.5809,  1.6490],
        [ 0.5468, -1.0748, -1.0999,  ...,  0.6221,  0.4043,  1.1871]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-6.4953e-02,  4.8949e-01, -1.6785e-01,  3.2161e-02,  1.6662e-01,
         -5.2399e-02, -2.2948e-01,  4.9182e-02,  4.6450e-02,  8.7956e-02,
         -1.5255e-01, -1.7497e-01, -1.1554e-01,  1.2573e-01, -1.0433e-01,
         -5.4105e-02, -1.8230e-01,  1.3740e-01, -2.3481e-01,  1.3187e-01,
         -1.2519e-02,  4.3603e-01,  2.1724e-01, -1.5350e-01,  1.1973e-01,
          1.2409e-01, -6.9381e-02,  2.1468e-01,  1.4231e-01,  4.8021e-04,
         -2.2339e-02, -4.3038e-01,  1.3131e-01, -1.7203e-01,  3.9544e-01,
          4.4817e-02, -1.9148e-01,  2.6339e-03,  4.3857e-02,  5.7678e-02,
          8.8178e-02,  2.9729e-01,  6.5992e-01, -1.9761e-01, -2.2707e-01,
         -3.1928e-01,  4.1451e-01,  3.2538e-01, -1.2810e-01,  2.5703e-01,
     

smile_embeddings tensor([[ 2.6110e-02,  3.3841e-01, -3.8438e-01, -1.4106e-01,  1.5119e-01,
          1.0764e-02, -7.4311e-02, -9.0694e-02, -5.9160e-02,  3.1053e-01,
         -9.2098e-02,  9.1972e-02, -3.7672e-01,  1.6416e-01, -1.7707e-01,
          1.7194e-02, -1.0955e-02,  1.8065e-01, -3.9212e-02,  2.8285e-01,
         -1.6238e-01,  1.3987e-01,  2.3814e-01, -3.9027e-01,  1.4012e-01,
         -5.7903e-03, -1.0225e-01,  1.4737e-01, -5.6679e-02, -8.6185e-02,
         -2.6402e-01, -2.1550e-01, -2.2102e-02, -1.3512e-01,  2.1052e-01,
          1.3685e-01,  2.6300e-01, -2.2860e-01,  1.2296e-02, -5.8325e-02,
          1.3746e-01,  2.3990e-01,  1.0450e-01, -1.5471e-01, -2.1691e-02,
         -2.5817e-01,  1.0565e-01,  2.8558e-01, -1.2671e-01,  5.5671e-02,
          3.1591e-01,  8.7503e-03,  3.9816e-02, -3.7459e-01,  1.4363e-01,
         -4.1231e-02, -5.7326e-03, -2.9868e-01,  4.7993e-02, -1.3426e-02,
          9.5967e-02, -8.6829e-02,  2.7247e-02,  8.7065e-02,  2.2592e-01,
         -3.6654e-01,

logits tensor([[ 0.8486, -1.6052, -0.8380,  ...,  0.6216,  0.1650,  0.7248],
        [ 0.8386, -0.5247,  0.1221,  ..., -0.0886, -0.1060,  2.2836]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.1159,  0.3269, -0.2917, -0.1031,  0.1105,  0.2394, -0.2140, -0.0875,
          0.0579, -0.0331, -0.1139, -0.1057, -0.0074, -0.1340,  0.0294, -0.0781,
         -0.2806,  0.0762, -0.2677,  0.2432, -0.0659, -0.1299,  0.1461, -0.4965,
          0.0328, -0.0527,  0.0134,  0.4222,  0.2307, -0.1624,  0.1706, -0.0522,
          0.0055, -0.1006,  0.2999, -0.0406,  0.0164,  0.0233,  0.0674, -0.1863,
         -0.0240, -0.0932,  0.3480,  0.0790,  0.0472, -0.0879,  0.1747,  0.3995,
         -0.2267,  0.0564, -0.0395, -0.0077,  0.0502, -0.1061,  0.1526,  0.0423,
         -0.4715, -0.2413, -0.1921,  0.0709,  0.0310, -0.1103,  0.1363, -0.1107,
          0.3671,  0.1086,  0.2457, -0.2047, -0.1778, -0.1809, -0.0645, -0.0923,
         -0.0723

input_to_regressor tensor([[-1.2385e-01,  4.9879e-01, -2.4156e-01,  7.0902e-02,  1.3939e-01,
          1.1502e-01, -2.5238e-01,  3.8559e-02,  1.5336e-01,  9.0845e-02,
         -1.4758e-01, -4.3473e-02, -6.4765e-02,  5.0734e-02, -1.2470e-01,
         -2.1940e-03, -2.8899e-01,  1.2026e-01, -2.8930e-01,  5.7729e-02,
         -4.4826e-02,  3.1044e-01,  1.9437e-01, -1.4843e-01,  7.1471e-02,
          7.5212e-03, -8.6659e-02,  3.3099e-01,  1.5164e-01,  2.0122e-02,
         -1.9501e-02, -4.5147e-01,  2.7397e-01, -1.4544e-01,  4.4944e-01,
          1.0654e-01, -2.0888e-01, -4.9866e-02,  3.7120e-02,  2.1509e-03,
          1.0291e-02,  4.1225e-01,  6.3711e-01, -9.4337e-02, -2.1496e-01,
         -2.0013e-01,  2.2310e-01,  4.6091e-01, -1.4153e-01,  1.6107e-01,
          3.2764e-01,  8.0298e-02, -5.8369e-02,  4.2076e-02,  2.7533e-01,
          2.2979e-01, -2.8754e-01, -4.9536e-02, -2.3519e-01, -9.0529e-02,
         -3.7667e-02,  2.4306e-01,  2.7327e-01, -2.8167e-02,  4.5921e-02,
         -1.7642e-0

logits tensor([[ 0.6090, -1.0322, -0.6404,  ..., -1.4557, -0.5946,  0.1906],
        [ 0.5186, -1.1011, -0.3905,  ..., -0.8288,  0.2258,  0.9275]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.1229,  0.1329, -0.1747, -0.4924,  0.1244,  0.3685, -0.1461, -0.1123,
          0.1383,  0.1287,  0.0518,  0.2471, -0.2588,  0.0447, -0.1914, -0.1465,
         -0.1655,  0.0196, -0.1128,  0.2275, -0.2647, -0.1269,  0.1616, -0.3143,
          0.2405, -0.0021, -0.1858,  0.1518,  0.1816, -0.4025,  0.0525,  0.0144,
         -0.0690, -0.0073,  0.0371,  0.1027,  0.4258, -0.1651,  0.1769, -0.0628,
          0.0404,  0.0620, -0.0496, -0.2139, -0.1021, -0.1566, -0.0310,  0.1758,
         -0.0117,  0.1320,  0.0972, -0.1309,  0.0375, -0.3263,  0.2534,  0.1030,
         -0.2355, -0.2582,  0.0350, -0.0267,  0.1334,  0.0523, -0.1528, -0.3665,
          0.3572,  0.1636,  0.0548, -0.0074, -0.2792, -0.4066, -0.2892, -0.1657,
         -0.1134

logits tensor([[-0.2040, -1.0950, -0.1784,  ..., -0.4692,  0.6279,  1.3970],
        [-0.2966, -0.8822, -0.8697,  ...,  0.3624,  0.5809,  1.6490]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-1.1886e-01,  4.6715e-01, -3.9242e-01, -5.6648e-02,  2.3051e-01,
         -7.9850e-02, -1.2909e-01,  1.2698e-01,  2.9278e-04,  1.9866e-01,
         -2.0503e-01, -7.6236e-02, -1.3406e-01,  1.0818e-01, -1.9285e-01,
         -2.1026e-01, -2.6744e-01,  7.7921e-02, -2.6004e-02,  3.6493e-01,
          7.7501e-02,  4.6728e-01,  1.8197e-01, -3.3289e-01,  5.1328e-02,
         -5.1552e-02, -7.6283e-02,  2.1816e-01,  7.1939e-02,  1.4466e-01,
         -9.7954e-02, -3.2622e-01,  1.7319e-01, -1.2481e-02,  2.5172e-01,
          9.9632e-02, -2.2241e-01, -6.5655e-02,  8.5266e-02, -8.6823e-02,
          1.2719e-02,  4.1718e-01,  5.9798e-01, -3.7007e-02, -1.3420e-01,
         -3.8914e-02,  2.3205e-01,  3.8203e-01, -2.2422e-01,  3.3961e-01,
     

logits tensor([[ 0.6032, -2.1316, -0.1725,  ..., -1.5678, -0.5253, -0.3527],
        [-0.2252, -1.2326, -0.0553,  ..., -0.8273,  0.6900,  1.5285]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[ 0.0570,  0.1315, -0.0359, -0.1528, -0.0666,  0.4331, -0.1759, -0.0366,
          0.1613,  0.0674, -0.0428,  0.1146,  0.0328,  0.2268, -0.0793, -0.1074,
         -0.2651, -0.1528, -0.3410,  0.1850, -0.2020, -0.1346,  0.2593, -0.1410,
          0.1251, -0.1465, -0.2690, -0.0624,  0.0277, -0.1751,  0.0348, -0.2200,
         -0.2585, -0.0749, -0.0008,  0.1900,  0.3133, -0.2829, -0.0393,  0.0411,
          0.2425,  0.0101,  0.1129, -0.2097,  0.0711, -0.1552,  0.0955,  0.2626,
         -0.0232,  0.0479,  0.1276, -0.1195,  0.1037, -0.2290, -0.0103,  0.0017,
         -0.1378, -0.1894, -0.1412,  0.1664,  0.0849, -0.0700,  0.0380, -0.5302,
          0.3029,  0.0559,  0.0564,  0.0970, -0.0265, -0.1760, -0.4358, -0.2926,
         -0.3366

logits tensor([[-0.2097, -1.2207, -0.0934,  ..., -0.4770,  0.6482,  1.0791],
        [ 0.2715, -1.6248, -0.7729,  ..., -0.3850,  0.2336,  1.5361]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.0392,  0.4358, -0.4057, -0.0023,  0.2100, -0.1184, -0.0326,  0.0177,
          0.0326,  0.2623, -0.2321,  0.1048,  0.0303,  0.1713, -0.2078, -0.0939,
         -0.2308,  0.1874,  0.3395,  0.4208,  0.2200,  0.3366,  0.1670, -0.3259,
          0.0966, -0.0058, -0.1217,  0.1605,  0.0595, -0.0172, -0.2049, -0.0635,
         -0.0971, -0.0613,  0.2489,  0.0046, -0.1277, -0.1269, -0.1442, -0.0055,
          0.1100,  0.0933,  0.5831,  0.2195, -0.1603,  0.0608,  0.1281,  0.1838,
         -0.0956,  0.3005,  0.2077,  0.1822,  0.0611, -0.0051,  0.2728, -0.0255,
          0.0609, -0.0900, -0.0676, -0.1544, -0.1192, -0.0142,  0.0333,  0.0287,
          0.0949, -0.3690,  0.4029, -0.1486, -0.0287, -0.0453,  0.1513,  0.0623,
         -0.4214

logits tensor([[-0.1045, -1.9911, -0.1706,  ..., -2.2785, -0.2759,  0.5606],
        [-0.5863, -0.2921, -0.3963,  ...,  0.5901,  0.3074,  1.9838]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[ 0.0711,  0.2773, -0.2529, -0.3024,  0.0179,  0.3023, -0.1976, -0.0410,
         -0.1183,  0.1601, -0.1888, -0.0465, -0.0283,  0.2578, -0.2385,  0.0167,
         -0.1575,  0.0281, -0.2835,  0.3358, -0.2938, -0.0110,  0.1583, -0.3457,
          0.1066, -0.1444, -0.0389,  0.0807,  0.2838, -0.0004,  0.2081, -0.1538,
         -0.1332, -0.1957, -0.0743, -0.0668, -0.0556, -0.2730, -0.1343, -0.0190,
          0.2413,  0.1604,  0.2377, -0.2382,  0.1493, -0.0074, -0.0505,  0.1310,
         -0.1825,  0.0584,  0.2207, -0.2301,  0.1840, -0.3300,  0.1192, -0.0364,
         -0.1408, -0.3509,  0.0270,  0.1818,  0.1252, -0.1407,  0.0266, -0.3523,
          0.0260, -0.0043,  0.0512,  0.1258, -0.0524, -0.2687, -0.2551, -0.1455,
         -0.4355

logits tensor([[-0.1306, -0.8486, -0.0083,  ..., -0.7752,  0.7274,  1.6186],
        [-0.3569, -0.8819, -0.8136,  ...,  0.2285,  0.5567,  1.5719]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[ 0.1159,  0.3754, -0.2609,  0.0564,  0.2154,  0.0026,  0.0513, -0.1015,
          0.0678,  0.1123, -0.2585, -0.0313, -0.1150,  0.1133, -0.2377,  0.0009,
         -0.2223,  0.1974,  0.0114,  0.2387, -0.0330,  0.2641,  0.1758, -0.2057,
          0.1673, -0.0994, -0.0260,  0.2472,  0.2906, -0.0925, -0.0714, -0.1937,
          0.1181, -0.0888,  0.1811,  0.0019, -0.0300, -0.0470,  0.0991,  0.0848,
          0.1294,  0.1965,  0.4316, -0.0918, -0.0839, -0.0254,  0.1498,  0.1692,
         -0.1912,  0.1804,  0.2563,  0.0921,  0.0147, -0.1317,  0.3267,  0.0404,
          0.0073,  0.0450, -0.0237, -0.1572, -0.0752,  0.0629,  0.1381,  0.2972,
          0.1865, -0.3102,  0.2690, -0.1664, -0.0369, -0.0792,  0.0094, -0.0558,
         -0.3208

input_to_regressor tensor([[-5.9825e-02,  3.5950e-01, -1.9694e-01, -5.0311e-03,  2.2162e-01,
          8.1311e-02, -1.7365e-01,  2.2891e-03,  8.0519e-02, -2.6775e-02,
         -1.4948e-01, -5.1327e-02, -9.0185e-02,  6.7174e-02, -1.6406e-01,
          8.3481e-02, -1.9835e-01,  2.4794e-01, -1.7930e-01,  1.8436e-01,
          3.4201e-02,  3.5836e-01,  1.4236e-01, -2.7714e-01,  2.0932e-01,
          4.4819e-02, -8.4679e-02,  3.3172e-01,  1.6273e-01, -6.3136e-02,
         -3.1513e-02, -3.6942e-01,  7.9748e-02, -3.1983e-02,  3.0112e-01,
          2.3839e-02, -2.4505e-01,  1.3014e-01,  4.8193e-02,  6.1439e-03,
         -5.5380e-03,  3.8717e-01,  6.6794e-01, -8.7460e-02, -2.9179e-01,
         -2.7560e-01,  3.2263e-01,  3.8806e-01, -1.8228e-01,  2.3294e-01,
          3.0007e-01,  9.9386e-02,  2.2254e-02,  5.3506e-02,  3.0591e-01,
          9.4695e-02, -2.1505e-01, -1.1501e-01, -3.0576e-01, -1.9519e-02,
         -4.0852e-03,  1.6203e-01,  2.3841e-01, -6.7804e-02, -2.7922e-03,
         -4.8071e-0

logits tensor([[ 0.4091, -1.9204, -1.4322,  ..., -0.4521, -0.0406,  0.3880],
        [ 0.7302, -1.2915, -0.5496,  ..., -1.3506, -0.3852,  0.6485]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.0045,  0.2286, -0.2663, -0.2119,  0.0256,  0.1777, -0.1949,  0.0792,
         -0.0409,  0.0131, -0.0389,  0.2078, -0.1535,  0.0802, -0.2038, -0.2631,
         -0.4176,  0.2435, -0.1874,  0.3353, -0.1860, -0.0046,  0.1924, -0.5835,
          0.1198, -0.0231, -0.1750,  0.1239,  0.3175, -0.0274,  0.0899, -0.0462,
          0.0424,  0.0143,  0.1449,  0.1223,  0.2934, -0.0399,  0.0418, -0.0356,
         -0.0025,  0.0126, -0.0096, -0.0939,  0.1676, -0.1318,  0.0600,  0.1329,
         -0.4065,  0.1682,  0.0631,  0.1266,  0.0184, -0.2757,  0.4108,  0.0349,
         -0.1783, -0.4574,  0.1647,  0.0328, -0.0048, -0.0387,  0.1404, -0.2130,
          0.2836,  0.1326,  0.1249, -0.1357, -0.0754, -0.1543,  0.0015, -0.2604,
         -0.0607

input_to_regressor tensor([[-0.1671,  0.5572, -0.2628, -0.0133,  0.3089,  0.1222, -0.1543,  0.1045,
          0.1917, -0.0148, -0.0582,  0.0070, -0.1319,  0.0610, -0.0487, -0.0157,
         -0.2633,  0.1694, -0.1400,  0.1810, -0.1082,  0.3355,  0.1066, -0.1317,
          0.0771,  0.0904, -0.0253,  0.1744,  0.1263,  0.1556, -0.0653, -0.4493,
          0.2483, -0.0723,  0.3499,  0.0280, -0.1485, -0.0234,  0.0640, -0.0450,
          0.0121,  0.3130,  0.4954, -0.2197, -0.0656, -0.3190,  0.2522,  0.3289,
         -0.2073,  0.1090,  0.2954,  0.0271, -0.0514, -0.0325,  0.2850,  0.1691,
         -0.2651,  0.0719, -0.3103, -0.0033, -0.0677,  0.1994,  0.2012,  0.0191,
         -0.0210, -0.1159,  0.1606, -0.1805, -0.1704, -0.1071, -0.1155, -0.1102,
         -0.4840,  0.3094, -0.1430,  0.2195,  0.2510,  0.1893, -0.1811,  0.3095,
          0.0290,  0.0735,  0.0622,  0.2763,  0.1642, -0.0781, -0.0785,  0.0621,
         -0.0126, -0.0930,  0.2168,  0.3691, -0.1014,  0.1407,  0.0832, -0.2434,
         

logits tensor([[-0.2966, -0.8822, -0.8697,  ...,  0.3624,  0.5809,  1.6490],
        [ 0.5468, -1.0748, -1.0999,  ...,  0.6221,  0.4043,  1.1871]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.1830,  0.3989, -0.2099,  0.0469,  0.1220, -0.0597, -0.2606, -0.0037,
          0.1315,  0.0288, -0.0491,  0.0716,  0.0237, -0.0353, -0.0830,  0.0074,
         -0.3185,  0.1655, -0.2219,  0.2553,  0.0124,  0.2431,  0.0695, -0.1683,
          0.1805,  0.1485, -0.0422,  0.1978,  0.0289,  0.0123, -0.1350, -0.3585,
          0.1369, -0.0648,  0.3545,  0.0545, -0.1143,  0.0231, -0.0394,  0.0623,
         -0.0037,  0.4024,  0.5973, -0.1529, -0.0774, -0.2637,  0.3195,  0.3617,
         -0.2607,  0.1538,  0.3116,  0.0683,  0.0235,  0.0726,  0.2432,  0.1080,
         -0.2079,  0.0054, -0.3502,  0.0008, -0.1184,  0.1613,  0.2684,  0.0255,
         -0.0333, -0.0640,  0.2543, -0.1367, -0.1964, -0.1071, -0.0952,  0.0966,
         -0.2839

input_to_regressor tensor([[-5.0451e-02,  3.3891e-01, -2.9545e-01,  7.3052e-02,  1.7951e-01,
          7.8743e-02, -2.3421e-01,  7.7263e-02,  1.7452e-01,  2.4633e-01,
         -1.9933e-01, -5.0928e-02, -1.2732e-01,  3.9576e-02, -1.0648e-01,
         -6.8079e-02, -2.4654e-01,  8.6257e-02, -2.9777e-01,  2.1732e-01,
         -4.3342e-02,  3.5402e-01,  2.2793e-01, -6.7068e-02, -5.5485e-03,
          9.3485e-02, -7.9145e-02,  1.4690e-01,  2.1284e-01, -1.9488e-02,
         -6.8777e-02, -3.8700e-01,  3.6310e-01, -5.4675e-02,  2.2596e-01,
         -9.7299e-02, -1.5500e-01, -5.4810e-02, -1.0636e-01,  2.5112e-02,
          2.9958e-02,  3.4962e-01,  5.7136e-01,  7.8665e-03, -2.3106e-01,
         -1.6054e-01,  3.3269e-01,  4.1448e-01, -9.0922e-02,  1.1191e-01,
          2.2861e-01,  2.9679e-02, -9.5638e-02, -6.4016e-02,  2.8739e-01,
          2.4919e-01, -2.5522e-01, -6.5318e-02, -2.7632e-01, -2.7959e-02,
          6.7636e-02,  2.2845e-01,  2.1247e-01, -1.1427e-02,  3.7230e-02,
         -8.7949e-0

logits tensor([[ 0.3627, -1.2584,  0.1588,  ...,  0.0234, -0.2399,  1.1298],
        [ 0.1387, -1.5587, -0.0978,  ..., -0.2143, -0.2845,  0.7609]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[ 0.0216,  0.0902, -0.1461, -0.1725,  0.0712,  0.3340, -0.2587, -0.1834,
         -0.1133,  0.0703,  0.1918, -0.0071, -0.1657,  0.0311, -0.1629, -0.1434,
         -0.2838,  0.0504, -0.2594,  0.0664, -0.1793, -0.2312,  0.3874, -0.2852,
         -0.0281, -0.1835, -0.0994,  0.1843,  0.2142, -0.0113,  0.2930, -0.2243,
         -0.1783,  0.0237,  0.0046,  0.1382, -0.0172, -0.2133, -0.1172, -0.3707,
          0.2895, -0.0601,  0.0557,  0.0193,  0.1030, -0.2189,  0.0034,  0.2649,
          0.0811, -0.0780,  0.3523, -0.1408, -0.0672, -0.2051,  0.3709,  0.1051,
         -0.0886, -0.2179, -0.0281,  0.2424,  0.2168, -0.0843, -0.0049, -0.2288,
          0.1482,  0.1447,  0.1434, -0.1281, -0.0893, -0.1929, -0.1993, -0.2227,
         -0.3244

logits tensor([[-0.2188, -0.9537, -0.9130,  ...,  0.4257,  0.5275,  1.5223],
        [ 0.3443, -0.9444, -0.2595,  ..., -0.3991,  0.7289,  1.1573]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.1545,  0.3614, -0.1078, -0.0280,  0.1495,  0.0537, -0.1510, -0.0479,
          0.2633,  0.0517, -0.1163, -0.0731, -0.1741,  0.0438, -0.1471, -0.0047,
         -0.3973,  0.0877, -0.3153,  0.1399, -0.0906,  0.3591,  0.1899, -0.3092,
         -0.0193,  0.0622,  0.0589,  0.4368,  0.1460,  0.0273,  0.0501, -0.4534,
          0.1807, -0.1469,  0.2523, -0.0253, -0.1896,  0.0468,  0.0075,  0.0736,
          0.0927,  0.2976,  0.5329, -0.1951, -0.2453, -0.2126,  0.2834,  0.3440,
         -0.1227,  0.1141,  0.3014,  0.0466,  0.0942, -0.0346,  0.2662,  0.1855,
         -0.2108, -0.0526, -0.3277, -0.0835,  0.0095,  0.2842,  0.2829, -0.0743,
          0.0372, -0.1280,  0.1607, -0.1966, -0.2423, -0.1553, -0.1256, -0.0294,
         -0.4282

logits tensor([[-0.2201, -1.2003, -0.1129,  ..., -0.4003,  0.5902,  1.1035],
        [-0.8821, -0.8786, -0.3251,  ..., -0.9374,  0.3461,  0.9004]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-3.9366e-02,  3.7993e-01, -2.9442e-01,  1.9424e-03,  2.2167e-01,
         -1.9019e-01, -6.4721e-02,  6.7844e-02,  3.5995e-02,  1.6886e-01,
         -1.9808e-01,  1.0866e-02, -3.8818e-02,  1.5924e-01, -2.9890e-01,
         -2.1692e-02, -1.3758e-01,  3.3810e-01,  2.4204e-01,  3.2265e-01,
          3.1070e-01,  2.9083e-01,  1.5876e-01, -3.6079e-01,  2.5222e-02,
         -1.0256e-02, -9.9543e-02,  3.3917e-01,  4.3766e-02,  5.7011e-02,
         -9.9155e-02, -1.0386e-01,  1.0142e-02, -1.6069e-01,  1.7535e-01,
          3.8982e-02, -1.4811e-01, -6.8262e-02,  7.9692e-02, -7.8987e-02,
          9.2639e-02,  2.6830e-01,  4.8126e-01,  9.0409e-02, -2.1644e-01,
         -1.0026e-01,  1.6433e-01,  3.5420e-01, -1.3533e-01,  2.2868e-01,
     

input_to_regressor tensor([[-9.4252e-02,  4.9110e-01, -2.3651e-01,  1.1394e-02,  1.9511e-01,
          5.8363e-02, -1.4626e-01, -2.6641e-02,  1.0406e-01, -8.3794e-02,
         -9.5996e-02, -4.6323e-02, -1.4841e-01,  7.9346e-02, -1.2584e-01,
          3.9501e-02, -2.9889e-01,  6.2847e-02, -2.3171e-01,  2.0387e-01,
         -8.2152e-02,  3.7202e-01,  1.5764e-01, -2.2091e-01,  5.1388e-02,
          1.2245e-01, -2.4726e-03,  2.3874e-01,  9.8500e-02,  7.8842e-02,
         -3.0302e-02, -3.1191e-01,  1.2136e-01, -2.5435e-02,  3.1586e-01,
          6.4204e-02, -2.0595e-01,  6.8801e-02, -8.9549e-03, -4.7648e-02,
         -7.9460e-03,  3.6919e-01,  5.8591e-01, -1.3536e-01, -1.6539e-01,
         -1.9718e-01,  2.8366e-01,  3.9887e-01, -2.2632e-01,  1.8871e-01,
          2.7762e-01,  1.0676e-01, -5.1385e-05,  4.5092e-02,  2.5656e-01,
          1.2627e-01, -2.3476e-01,  1.8215e-02, -4.1597e-01, -3.5646e-02,
         -4.0535e-02,  2.7244e-01,  1.5916e-01, -1.0181e-01,  2.5181e-02,
         -1.4457e-0

smile_embeddings tensor([[ 0.0253,  0.3459, -0.2529,  0.0343,  0.1668,  0.2519, -0.0760, -0.0524,
          0.0087,  0.1218,  0.0384,  0.0282, -0.1285,  0.0170, -0.0056, -0.0560,
         -0.2513,  0.0162, -0.2027,  0.2686, -0.1753, -0.0380,  0.2618, -0.2561,
          0.1597, -0.0786, -0.1416,  0.2414,  0.1145, -0.0955,  0.2043, -0.0078,
          0.1069,  0.1553,  0.1031,  0.1409,  0.0366, -0.0685, -0.1968, -0.1714,
          0.0303, -0.0335,  0.3446,  0.0809,  0.0039, -0.1139,  0.0884,  0.3007,
         -0.2400, -0.0511, -0.0243, -0.1040, -0.0539, -0.3141,  0.1784,  0.0488,
         -0.1141, -0.0261, -0.1985, -0.0456, -0.1432, -0.0889,  0.0242, -0.1394,
          0.4711,  0.2378,  0.1642, -0.0627, -0.0867, -0.1399, -0.3688, -0.2733,
         -0.1690,  0.1093,  0.0377, -0.2609,  0.0152,  0.1734, -0.2831,  0.0708,
         -0.1831,  0.0698, -0.0166,  0.1190,  0.3202,  0.2288,  0.2457,  0.1367,
          0.1411, -0.0413,  0.2127,  0.0521, -0.1065, -0.0105,  0.0078, -0.2373,
         -0

logits tensor([[-0.5510, -0.7370,  0.1609,  ..., -0.4974,  0.4668,  1.9421],
        [ 0.4200, -1.6721, -1.4597,  ..., -0.1327,  0.1762,  0.5013]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.1459,  0.4093, -0.4487, -0.0423,  0.2234, -0.0476, -0.0993,  0.1244,
         -0.0215,  0.2302, -0.2585,  0.1138, -0.0890,  0.1723, -0.1122,  0.0338,
         -0.3212,  0.0725,  0.0036,  0.3523,  0.0915,  0.3999,  0.2570, -0.1693,
          0.0780, -0.0257, -0.1985,  0.1902,  0.0988, -0.0674, -0.0257, -0.2829,
          0.1893, -0.1249,  0.2489,  0.0248, -0.1666, -0.1318,  0.1468, -0.0071,
          0.0591,  0.1839,  0.6339,  0.1423, -0.1247, -0.1637,  0.4201,  0.3442,
         -0.3446,  0.2787,  0.4789,  0.1522,  0.0778,  0.0456,  0.3990,  0.0946,
         -0.1413, -0.0835, -0.1936, -0.0206, -0.1998,  0.2245,  0.1959,  0.0084,
          0.1207, -0.2310,  0.2464, -0.1812,  0.0070, -0.0668,  0.0406, -0.1314,
         -0.3310

logits tensor([[ 0.6032, -2.1316, -0.1725,  ..., -1.5678, -0.5253, -0.3527],
        [-0.1901, -0.8716, -0.7683,  ...,  0.4039,  0.6170,  1.7235]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-4.1026e-02,  2.0824e-01, -5.0033e-02, -3.0835e-01, -1.4516e-01,
          4.0395e-01, -1.8197e-01, -1.7604e-02,  1.2849e-01,  1.7815e-02,
         -1.7091e-02,  8.8241e-02, -3.8869e-02,  2.3755e-01, -7.8752e-03,
          1.2422e-01, -3.6826e-01,  3.7044e-02, -2.1783e-01,  1.0496e-01,
         -4.2938e-03, -2.9934e-02,  3.0720e-01, -9.0649e-02,  2.7535e-01,
         -6.6130e-02, -1.9958e-01,  1.0249e-01,  9.7257e-02, -1.6494e-01,
          3.8548e-03, -2.6024e-01, -1.5298e-01, -3.9668e-02,  7.7818e-04,
          1.5037e-01,  3.0218e-01, -3.2887e-01,  3.1496e-02, -6.3400e-02,
          1.2021e-01, -1.2164e-01,  8.6100e-02, -2.2876e-01,  7.1185e-02,
         -1.3592e-01,  7.7456e-02,  2.3850e-01, -7.1322e-02,  8.3599e-02,
     

logits tensor([[ 0.3443, -0.9444, -0.2595,  ..., -0.3991,  0.7289,  1.1573],
        [ 0.6555, -1.3971, -1.1400,  ...,  0.3161,  0.3425,  0.9145]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-2.0302e-01,  3.3159e-01, -3.2374e-01, -2.5507e-01,  1.6953e-01,
          1.0419e-01, -3.3081e-01, -1.1908e-01, -5.0007e-02,  2.6306e-01,
          6.4569e-02,  9.6846e-03, -4.2541e-01,  7.3419e-02, -1.2858e-01,
         -1.1135e-01, -3.7688e-01,  1.1322e-01, -8.9925e-02,  3.0492e-01,
         -2.9769e-01,  3.1645e-01,  3.2278e-01, -2.7615e-01,  6.9016e-02,
         -1.9588e-01, -1.9552e-01,  4.3468e-02,  2.2833e-01, -1.6676e-01,
          1.2630e-02, -2.0104e-01,  2.7515e-02,  1.2682e-01,  1.9055e-01,
          2.4695e-01,  2.2803e-01, -7.7577e-02,  3.0133e-02, -1.0525e-01,
         -7.7887e-04,  1.4890e-01,  2.6110e-01, -1.1031e-01, -1.9814e-01,
         -2.8810e-01,  2.1433e-01,  2.1485e-01, -1.9673e-01,  3.5676e-02,
     

smile_embeddings tensor([[-0.1310,  0.0694, -0.0201,  0.0366,  0.1145,  0.1369, -0.0225,  0.0522,
          0.0970,  0.0267,  0.0077, -0.0948, -0.1967,  0.0738, -0.1529, -0.3187,
         -0.2719, -0.0277, -0.1920,  0.2287, -0.0852,  0.1569,  0.2613, -0.4094,
          0.0528, -0.0080, -0.0849, -0.0943,  0.1550,  0.0167,  0.0089, -0.1619,
          0.1336,  0.2093,  0.2273,  0.0732, -0.0169, -0.1323,  0.3029, -0.1654,
          0.1106,  0.2112,  0.4746, -0.0201, -0.0543, -0.2271,  0.2700,  0.3373,
          0.2647, -0.0644,  0.2162, -0.0763, -0.0374, -0.1614, -0.0608,  0.2329,
         -0.3551, -0.0642, -0.2893,  0.1025, -0.1069,  0.2664,  0.1133, -0.1695,
          0.2108,  0.0209,  0.2033, -0.0694, -0.1189, -0.1394, -0.4290, -0.0077,
         -0.2525,  0.2399, -0.0359,  0.0348,  0.1744,  0.3075, -0.1372, -0.0530,
         -0.1960,  0.1233,  0.1144,  0.2444,  0.2940,  0.0434,  0.1234,  0.1474,
         -0.0861, -0.1492,  0.1781, -0.0839, -0.2282,  0.1344,  0.2078, -0.1559,
          0

logits tensor([[ 0.3627, -1.2584,  0.1588,  ...,  0.0234, -0.2399,  1.1298],
        [-0.1657, -2.1037,  0.2269,  ..., -1.8627,  0.0621,  1.3463]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[ 0.0060,  0.1764, -0.3050, -0.1525, -0.0942,  0.2824, -0.1993, -0.3223,
         -0.1323, -0.0038,  0.2308,  0.0155, -0.2568,  0.1581, -0.0274, -0.0391,
         -0.2332,  0.1227, -0.2495,  0.1417, -0.2830, -0.1801,  0.2386, -0.2756,
         -0.1077, -0.2638, -0.1511,  0.3392,  0.1220, -0.1605,  0.2721, -0.1451,
         -0.2423, -0.0701,  0.1437, -0.0008, -0.0469, -0.1751, -0.0725, -0.4294,
          0.2700, -0.0704,  0.0680,  0.0365,  0.2610, -0.2676, -0.0962,  0.2551,
         -0.0083, -0.0259,  0.3118, -0.2020,  0.0469, -0.2364,  0.3050,  0.1278,
         -0.1360, -0.2225, -0.0421,  0.2656,  0.2413, -0.0897,  0.0797, -0.2630,
          0.2703,  0.0435,  0.0364, -0.0059,  0.0477, -0.0936, -0.3604, -0.0932,
         -0.2333

logits tensor([[-0.5995, -0.5397,  0.0674,  ..., -0.2653, -0.1414,  1.5422],
        [ 0.1387, -1.5587, -0.0978,  ..., -0.2143, -0.2845,  0.7609]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.0066,  0.2283, -0.3080, -0.0741,  0.2507,  0.0327, -0.0797, -0.0490,
          0.0195,  0.0210, -0.2370, -0.0056, -0.0379,  0.1307, -0.2036, -0.2213,
         -0.2537,  0.0956,  0.1439,  0.1788,  0.1058,  0.2429,  0.1401, -0.3028,
          0.0587, -0.2378, -0.1211,  0.1976,  0.1908, -0.0091, -0.0871, -0.1845,
          0.0999,  0.0607,  0.2323,  0.1478, -0.1431, -0.2167,  0.2189, -0.0715,
          0.1544,  0.1594,  0.5147,  0.0491, -0.2380, -0.1612,  0.3250,  0.2211,
         -0.0336,  0.3336,  0.3140,  0.0275,  0.0419, -0.1050,  0.3177, -0.0006,
         -0.1204,  0.0254, -0.3012,  0.0581, -0.0515,  0.1854,  0.1711,  0.1227,
          0.0692, -0.2514,  0.4053, -0.1721, -0.0050, -0.1138, -0.0242, -0.0207,
         -0.3122

logits tensor([[-0.2040, -1.0950, -0.1784,  ..., -0.4692,  0.6279,  1.3970],
        [-0.0710, -1.4718, -0.0976,  ..., -0.6046, -0.0276,  0.0330]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-1.7928e-01,  3.6784e-01, -3.2278e-01, -6.1310e-03,  2.4460e-01,
         -2.1999e-01, -1.4548e-01,  1.9719e-01,  2.5916e-03,  1.0807e-01,
         -2.3841e-01,  2.2924e-02, -1.6314e-01,  1.3614e-01, -3.5476e-01,
         -8.1887e-02, -2.7332e-01,  4.1639e-02,  6.5279e-02,  3.6224e-01,
          1.2069e-01,  3.9071e-01,  1.0139e-01, -3.0717e-01,  9.1244e-02,
         -7.9465e-02, -9.7861e-02,  1.2542e-01, -9.3340e-02,  2.6804e-02,
         -4.2560e-02, -2.3584e-01,  7.4646e-02,  1.2801e-02,  1.6837e-01,
          7.0080e-02, -2.3800e-01, -6.8513e-02,  1.1327e-01, -4.2955e-02,
          7.4137e-02,  3.8332e-01,  5.2349e-01,  4.7903e-02, -1.0139e-01,
         -1.6144e-01,  4.2864e-01,  3.5234e-01, -2.1998e-01,  2.6671e-01,
     

input_to_regressor tensor([[-1.6858e-01,  4.1498e-01, -2.3135e-01, -6.3239e-02,  2.4099e-01,
          1.8732e-01, -3.2682e-02, -1.3457e-01,  2.8803e-02, -2.4340e-02,
          5.6176e-02,  1.0544e-01, -1.5432e-01,  2.9342e-02, -1.4143e-01,
         -1.4329e-01, -3.1692e-01,  4.2894e-02, -3.7061e-01,  3.2738e-01,
         -1.0981e-01,  6.2471e-02,  2.1842e-01, -1.9163e-01, -2.1029e-02,
         -1.1033e-01, -3.0762e-01,  1.7579e-01,  1.8806e-01, -1.1739e-01,
          1.9182e-01, -7.7178e-02, -2.2846e-02,  1.8118e-01,  2.1219e-01,
          2.1616e-01,  4.6874e-02, -4.4633e-02, -4.5916e-02, -2.7995e-01,
          1.9852e-01,  9.9836e-02,  1.7005e-01, -2.2002e-02, -8.4338e-02,
         -1.8644e-01,  2.2626e-01,  2.4620e-01, -2.2769e-01,  6.3693e-02,
          1.8656e-02, -1.9936e-01,  4.3474e-02, -4.1708e-01,  1.7115e-01,
          1.6741e-01, -9.0974e-02, -2.6757e-01, -1.9137e-01,  3.3474e-02,
         -9.1292e-02, -9.9948e-02,  5.2810e-02, -1.6657e-01,  3.9808e-01,
          2.4020e-0

input_to_regressor tensor([[-6.5351e-02,  3.4675e-01, -3.5735e-01, -2.9701e-01, -1.0547e-02,
          1.1244e-01, -1.3076e-01,  1.7232e-02, -5.9326e-02,  3.8768e-01,
         -1.8627e-01,  1.2782e-01, -3.5882e-01,  3.0978e-01, -2.8120e-01,
         -6.1403e-03,  1.9576e-02,  2.4970e-01, -3.2864e-02,  2.5397e-01,
         -2.5679e-01,  2.5882e-01,  2.3661e-01, -5.5263e-01, -1.3459e-02,
         -6.1374e-02, -1.5216e-01,  1.1753e-01,  4.6663e-02,  3.6620e-02,
         -1.7960e-01, -1.6860e-01, -5.6345e-02, -2.0733e-01,  5.0053e-02,
          7.5759e-02,  2.4326e-01, -1.3619e-01,  1.3038e-01, -1.8613e-01,
          9.1404e-02,  3.5096e-01,  2.2836e-01, -1.6201e-01,  1.1200e-01,
         -1.7841e-01, -5.2469e-02,  2.5227e-01, -9.9118e-02, -6.4385e-03,
          2.9698e-01, -2.4851e-02, -2.8985e-02, -4.5037e-01,  2.1334e-01,
          7.2288e-02,  9.9017e-03, -2.0690e-01,  2.5860e-02,  7.6874e-02,
          1.3253e-01,  7.5120e-02,  1.5016e-01,  9.7069e-02,  1.7547e-01,
         -2.0864e-0

logits tensor([[ 0.1294, -1.8128, -0.1654,  ..., -1.4135,  0.1382,  0.7863],
        [-0.2966, -0.8822, -0.8697,  ...,  0.3624,  0.5809,  1.6490]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[ 0.0311,  0.3709, -0.3466, -0.0709,  0.2325,  0.0748, -0.1925, -0.0821,
         -0.1436,  0.2980, -0.1721,  0.1147,  0.0155,  0.2066, -0.3707,  0.1030,
         -0.1374,  0.1986,  0.1798,  0.3501,  0.0915,  0.2380,  0.1927, -0.3504,
          0.2050,  0.0244, -0.0863,  0.1959,  0.2891,  0.0045, -0.0832, -0.1685,
         -0.0497, -0.0979,  0.1125,  0.0851,  0.1625, -0.1065,  0.1747, -0.1189,
          0.1371,  0.2462,  0.3521, -0.0722, -0.0313, -0.0430,  0.0386,  0.1688,
         -0.0173,  0.0799,  0.3104,  0.1226,  0.0053, -0.1376,  0.3580, -0.1677,
         -0.0995, -0.2671,  0.0887, -0.1722,  0.1929,  0.1674, -0.0890,  0.0201,
          0.1375, -0.2258,  0.1732, -0.1243,  0.0260, -0.2123,  0.1808, -0.1802,
         -0.2963

input_to_regressor tensor([[-1.5126e-02,  6.7929e-02, -1.9064e-01, -2.6591e-01,  1.2518e-01,
          1.2589e-02,  8.7844e-02,  9.0850e-02,  8.6925e-02,  1.0864e-01,
         -7.6584e-02,  1.3774e-01, -2.0872e-01,  2.2937e-01, -2.7753e-01,
         -8.8909e-02, -2.4925e-01,  2.3380e-01, -2.8668e-03,  2.0015e-01,
         -2.8824e-01, -4.8167e-03,  2.4144e-01, -4.4756e-01, -1.3974e-02,
         -9.8437e-02, -1.6282e-01,  2.0003e-01,  8.1000e-02, -1.1146e-01,
         -2.5301e-01, -5.6321e-02, -1.2335e-01,  1.0187e-01, -9.1028e-03,
          1.6876e-01,  2.5123e-01, -2.1092e-01,  1.0327e-01, -1.6855e-01,
          2.6302e-01,  3.0781e-01,  2.0766e-01,  2.9161e-02,  5.2273e-02,
         -2.7710e-01, -3.1099e-02,  2.6105e-01, -9.8973e-02, -9.3205e-02,
          6.2974e-02,  1.3919e-03,  7.1571e-02, -3.4493e-01,  2.6283e-01,
          2.6178e-02, -1.5130e-01,  1.4084e-03, -2.5369e-01,  3.0817e-01,
          2.1329e-01, -2.9890e-02,  1.7134e-01, -4.5607e-02,  1.6770e-01,
         -1.5712e-0

logits tensor([[-0.2188, -0.9537, -0.9130,  ...,  0.4257,  0.5275,  1.5223],
        [-0.2040, -1.0950, -0.1784,  ..., -0.4692,  0.6279,  1.3970]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.0577,  0.4572, -0.2182,  0.0837,  0.1934,  0.0388, -0.2042, -0.0036,
          0.1391,  0.1218, -0.1852,  0.0099, -0.2032,  0.1800, -0.0721,  0.1437,
         -0.2620,  0.1514, -0.3787,  0.1884,  0.0110,  0.2853,  0.1125, -0.2096,
          0.1179,  0.0796,  0.0488,  0.2906,  0.1219,  0.0988, -0.0403, -0.3641,
          0.1708, -0.1495,  0.2131,  0.0178, -0.2347,  0.1287, -0.0403,  0.0213,
          0.0612,  0.2806,  0.5833, -0.1132, -0.2131, -0.2689,  0.2584,  0.3937,
         -0.1174,  0.2708,  0.2151,  0.0015, -0.0039,  0.0844,  0.2804,  0.0749,
         -0.3144, -0.1222, -0.2922,  0.0666,  0.0094,  0.1286,  0.1950, -0.1022,
         -0.0141, -0.0812,  0.2499, -0.2345, -0.1689, -0.2591,  0.0045,  0.0010,
         -0.3740

logits tensor([[ 0.8386, -0.5247,  0.1221,  ..., -0.0886, -0.1060,  2.2836],
        [-0.2040, -1.0950, -0.1784,  ..., -0.4692,  0.6279,  1.3970]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.1372, -0.0703, -0.2009, -0.0434,  0.0794,  0.0512, -0.0670,  0.0883,
          0.2536,  0.0280, -0.0255,  0.0933, -0.2600,  0.1570, -0.2262, -0.2950,
         -0.2249,  0.1162, -0.1876,  0.2221, -0.0764,  0.0761,  0.3339, -0.3382,
         -0.0281, -0.1378,  0.0532,  0.1046,  0.0719, -0.0909,  0.0595, -0.2521,
          0.0694,  0.2108,  0.1712,  0.0420, -0.0259, -0.2303,  0.3036, -0.1253,
          0.1989,  0.1826,  0.5572,  0.0210, -0.1948, -0.4295,  0.2754,  0.5009,
          0.3250,  0.0378,  0.3485,  0.0713, -0.0678, -0.1668, -0.1107,  0.1890,
         -0.2662,  0.0706, -0.3338,  0.1358, -0.0247,  0.2402,  0.1878, -0.0498,
          0.2421, -0.0924,  0.2879, -0.3095, -0.0096, -0.1684, -0.5302, -0.0204,
         -0.3695

logits tensor([[ 0.5599, -1.4345, -0.7970,  ...,  0.1868,  0.5402,  1.2432],
        [ 0.5599, -1.4345, -0.7970,  ...,  0.1868,  0.5402,  1.2432]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.1121,  0.4334, -0.2715, -0.0565,  0.2927,  0.2710, -0.0579, -0.0887,
         -0.0390,  0.1480, -0.2246, -0.1298, -0.0812, -0.0125, -0.0601, -0.1052,
         -0.3587,  0.2320, -0.2113,  0.2643, -0.0781, -0.0390,  0.1210, -0.3963,
          0.1113, -0.1916, -0.0849,  0.2657,  0.2049,  0.0013,  0.0807, -0.0277,
         -0.0224,  0.0310,  0.3090, -0.0441, -0.1840, -0.0488,  0.0918, -0.1912,
         -0.0318, -0.0308,  0.2904, -0.0939, -0.1424, -0.1664,  0.1183,  0.4159,
         -0.2503,  0.1379,  0.1252,  0.0564,  0.0350, -0.2040,  0.2809, -0.0029,
         -0.2201, -0.1201, -0.0666,  0.0419,  0.0251, -0.1479,  0.1345,  0.1255,
          0.4066,  0.1490,  0.3220, -0.1881, -0.1007, -0.2584, -0.1237, -0.0852,
         -0.3091

logits tensor([[-0.1431, -1.2431, -0.0626,  ..., -0.4679,  0.7815,  1.0820],
        [ 0.6032, -2.1316, -0.1725,  ..., -1.5678, -0.5253, -0.3527]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.0833,  0.3948, -0.3736,  0.2338,  0.3196, -0.0306,  0.0434, -0.0650,
          0.0409,  0.3646, -0.0982,  0.1312,  0.0992,  0.0105, -0.3095, -0.1212,
         -0.2578,  0.1431,  0.2332,  0.4377,  0.2498,  0.3840,  0.1827, -0.3055,
          0.1025, -0.1193, -0.0752,  0.1279,  0.0711, -0.0617, -0.2385, -0.1417,
          0.0406, -0.0726,  0.2114, -0.0319, -0.0811, -0.1877, -0.0023, -0.0962,
          0.0192,  0.2721,  0.5142,  0.1490, -0.1075,  0.0006,  0.2310,  0.1830,
         -0.1588,  0.2315,  0.3322,  0.0909,  0.2301, -0.0657,  0.1976,  0.0039,
          0.0294, -0.0861, -0.0514, -0.0200, -0.0326,  0.0694,  0.0072,  0.2374,
          0.1972, -0.1929,  0.3927, -0.1567, -0.0341, -0.1696,  0.0759, -0.0132,
         -0.3467

smile_embeddings tensor([[ 2.6002e-02,  4.5570e-01, -5.8692e-01,  4.0813e-02,  1.7496e-01,
         -2.1518e-01,  1.1251e-01,  1.2969e-01,  1.5744e-02,  1.8160e-01,
         -2.0919e-01,  1.1862e-01, -7.6702e-02,  1.6912e-01, -1.7983e-01,
          4.3971e-02, -2.9169e-01,  1.4062e-01,  1.0547e-01,  4.3999e-01,
          4.2932e-02,  3.6294e-01,  2.6456e-01, -3.5792e-01,  1.7998e-01,
         -1.4296e-01, -4.6733e-02,  2.6486e-01,  1.6566e-01,  7.8144e-02,
         -1.3141e-01, -1.5422e-01, -6.8508e-02, -4.7741e-02,  3.3523e-01,
          9.2458e-02,  7.7086e-04, -5.7631e-02,  6.8835e-02, -2.2368e-01,
          8.5614e-02,  3.5753e-01,  5.1055e-01,  5.5445e-02, -3.9872e-02,
         -1.5967e-02,  1.6530e-01,  2.1348e-01, -3.5768e-01,  2.8686e-01,
          3.3427e-01,  8.0899e-02,  2.2472e-01, -1.8014e-02,  2.7286e-01,
         -1.4083e-01,  6.3908e-03,  1.5749e-02, -9.4619e-02, -2.0971e-01,
          2.0524e-02,  1.0159e-01,  9.1283e-02,  9.1260e-02,  1.0813e-01,
         -2.8315e-01,

logits tensor([[-0.1431, -1.2431, -0.0626,  ..., -0.4679,  0.7815,  1.0820],
        [-0.2188, -0.9537, -0.9130,  ...,  0.4257,  0.5275,  1.5223]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-3.7268e-02,  3.5143e-01, -4.0595e-01, -9.4787e-03,  2.0112e-01,
         -1.5444e-01, -8.2956e-03,  4.4189e-02, -4.1152e-02,  2.9178e-01,
         -2.4046e-01,  7.0410e-02,  2.5862e-02,  9.2725e-02, -2.9567e-01,
          4.0094e-02, -2.0333e-01,  1.2849e-01,  1.7886e-01,  4.3197e-01,
          2.0758e-01,  5.8328e-01,  1.4372e-01, -3.4904e-01,  1.1648e-01,
         -1.0673e-01, -5.8010e-02,  2.7180e-01,  1.6949e-01,  1.6156e-01,
         -1.7179e-01, -6.9630e-02, -1.1749e-01, -6.2877e-02,  1.5566e-01,
          1.5643e-01, -6.9837e-02,  1.1018e-02,  1.6497e-01,  7.0118e-02,
          1.0436e-01,  2.3805e-01,  4.5788e-01,  3.9134e-02, -1.5392e-01,
         -3.0930e-02,  1.8356e-01,  2.5505e-01, -1.7167e-01,  3.4329e-01,
     

input_to_regressor tensor([[-1.1349e-01,  9.5216e-02, -2.5067e-01, -4.3327e-01, -1.2894e-02,
          1.7923e-01, -1.8634e-01, -8.6593e-02, -4.5398e-02,  5.5300e-02,
         -2.0621e-02,  1.3855e-01, -2.8289e-01,  6.3326e-02, -1.4678e-01,
          4.6312e-02, -1.5076e-01,  4.0434e-02, -2.1413e-02,  3.6053e-01,
         -2.5494e-01, -1.0200e-01,  2.5409e-01, -2.4784e-01,  1.5660e-01,
          6.6845e-02, -8.2213e-02,  4.6166e-02, -7.8558e-02, -6.7083e-02,
          1.1183e-01, -6.6138e-02, -2.0872e-01,  5.4761e-02, -7.6379e-02,
          6.5738e-02,  4.2392e-01, -1.3725e-01,  4.3868e-02, -1.8316e-01,
          1.0196e-01,  3.3062e-02,  2.7255e-02, -1.9839e-01,  1.2109e-01,
         -1.0654e-01, -7.1707e-02,  1.8913e-01,  2.8576e-02, -4.3705e-02,
          1.7067e-01, -1.0800e-01, -1.4586e-02, -2.5904e-01,  2.7868e-01,
          1.2420e-01, -2.3754e-01, -3.6685e-01,  1.3897e-01,  1.4203e-01,
          2.9937e-01, -9.8130e-03, -1.4064e-01, -3.5396e-01,  2.8241e-01,
          9.7810e-0

smile_embeddings tensor([[-4.7985e-02,  2.8997e-01, -3.8804e-01,  2.0606e-01,  1.3515e-01,
         -1.1381e-01, -1.2048e-01, -6.0610e-03, -4.9587e-02,  1.3980e-01,
         -2.1809e-01,  1.0955e-01, -1.5646e-02,  1.1312e-01, -2.4220e-01,
          3.5192e-02, -2.5896e-01,  2.0560e-01,  8.0952e-02,  3.7234e-01,
          1.9976e-01,  3.7301e-01,  2.5382e-01, -3.6032e-01,  1.1906e-01,
         -6.4650e-02, -1.0603e-01,  2.7404e-01,  5.6358e-02,  1.1093e-01,
         -8.9265e-02, -1.2354e-01,  3.7774e-02, -8.9006e-02,  3.7899e-01,
          1.3365e-01, -1.1959e-01, -4.0826e-03,  5.2195e-02, -2.4633e-02,
          2.2358e-01,  3.1568e-01,  4.8479e-01,  1.2115e-01, -1.1471e-01,
         -7.5406e-02,  2.5085e-01,  2.9834e-01, -2.0981e-01,  2.8312e-01,
          2.2861e-01,  1.0219e-01,  5.0753e-02,  8.8429e-02,  3.0091e-01,
          9.7293e-02,  3.5154e-02, -4.8129e-02, -2.7165e-02, -6.8123e-02,
         -3.9976e-02,  1.7783e-01,  9.8053e-02, -5.2546e-02,  1.2867e-01,
         -1.2322e-01,

logits tensor([[ 0.8386, -0.5247,  0.1221,  ..., -0.0886, -0.1060,  2.2836],
        [-0.2966, -0.8822, -0.8697,  ...,  0.3624,  0.5809,  1.6490]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.1425, -0.0168, -0.1144, -0.1211, -0.0216,  0.0854, -0.0631, -0.0647,
          0.1680, -0.0527, -0.0550,  0.0023, -0.2448,  0.1323, -0.2078, -0.2513,
         -0.2600,  0.1919, -0.0906,  0.3205, -0.0446,  0.1512,  0.3484, -0.3593,
         -0.0053, -0.0880, -0.0034,  0.0471,  0.1062,  0.0647,  0.1003, -0.1662,
          0.1835,  0.2446,  0.1562,  0.0979, -0.1277, -0.1084,  0.2504, -0.0918,
          0.2982,  0.1960,  0.3618,  0.0362, -0.1243, -0.4399,  0.2573,  0.3990,
          0.2662,  0.0485,  0.2297, -0.0400, -0.1349, -0.2962,  0.0479,  0.2586,
         -0.2093, -0.1266, -0.1831,  0.3116, -0.0575,  0.2781,  0.2515,  0.0608,
          0.2055,  0.1133,  0.2411, -0.1764,  0.0767, -0.1386, -0.3741,  0.0262,
         -0.2066

logits tensor([[ 0.6298, -1.8201, -0.5266,  ..., -0.3707,  0.4771,  1.4649],
        [-0.0710, -1.4718, -0.0976,  ..., -0.6046, -0.0276,  0.0330]],
       device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>) torch.Size([2, 768])
smile_embeddings tensor([[-0.1143,  0.2209, -0.1406, -0.3501,  0.0498,  0.1478, -0.2099, -0.1470,
         -0.0771,  0.2537,  0.0482,  0.0305, -0.2971,  0.0978, -0.3734, -0.0664,
         -0.3776, -0.0152, -0.2942,  0.3119, -0.3962, -0.0392,  0.2307, -0.2858,
          0.1960, -0.0603, -0.0552,  0.1560, -0.0172, -0.0850,  0.2081, -0.1214,
          0.1001,  0.1757,  0.0973,  0.0676, -0.0736, -0.0756,  0.0474, -0.1067,
          0.0814,  0.0372,  0.1826,  0.0949,  0.0306, -0.3274, -0.1299,  0.3422,
         -0.0432, -0.1420,  0.0602, -0.1181, -0.1305, -0.3835,  0.1317,  0.1820,
         -0.2351, -0.0829, -0.2623,  0.1862,  0.2624, -0.1170,  0.0678, -0.0972,
          0.3882,  0.2197,  0.2162, -0.2299, -0.2475, -0.3912, -0.1277, -0.2720,
         -0.1033

In [17]:
#test(model, loss_fn, train_dataloader, test_dataloader, scaler_target, device)