In [None]:
import datamol as dm
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, EarlyStoppingCallback, get_linear_schedule_with_warmup
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt

# The tokenizer and the single-output (num_labels=1) model share one ChemBERTa checkpoint.
_CHECKPOINT = "seyonec/ChemBERTa-zinc-base-v1"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT, num_labels=1)

# FreeSolv table: the "smiles" column is the input, the "expt" column is the target.
df = dm.data.freesolv()
X = df["smiles"]
y = df["expt"]


def tokenize_function(smiles):
    """Tokenize SMILES input with the module-level ``tokenizer``.

    The call pads to ``max_length`` and truncates, both at 128 tokens.

    Args:
        smiles: Whatever the tokenizer accepts (here: one SMILES string).

    Returns:
        The tokenizer's encoding object for ``smiles``.
    """
    fixed_length_options = {"padding": "max_length", "truncation": True, "max_length": 128}
    return tokenizer(smiles, **fixed_length_options)


# Eagerly tokenize every molecule, one SMILES string per call.
X_tokenized = X.apply(tokenize_function)

# Convert the tokenized data to the format required by Hugging Face
class FreeSolvDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with regression targets.

    Args:
        encodings: Mapping from feature name to a container indexable by
            sample position. Each row may be a tensor or a (nested) list.
        labels: Indexable sequence of numeric targets; its length defines
            the dataset length.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors plus a float32 ``labels`` entry."""
        item = {}
        for key, val in self.encodings.items():
            row = val[idx]
            if isinstance(row, torch.Tensor):
                # torch.tensor(<tensor>) emits a copy-construct UserWarning on
                # every access; clone().detach() is the supported way to copy.
                item[key] = row.clone().detach()
            else:
                item[key] = torch.tensor(row)
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

    def __len__(self):
        """Number of samples, taken from the label sequence."""
        return len(self.labels)

# 60/20/20 split: hold out 40 % of the data, then halve the hold-out into
# validation and test. Both calls use the same fixed seed.
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.4, random_state=42
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)

# Batch-tokenize each split; padding=True pads within each call, so the three
# encodings are padded independently of one another.
_batch_tokenizer_options = dict(padding=True, truncation=True, max_length=128, return_tensors="pt")
train_encodings = tokenizer(list(X_train), **_batch_tokenizer_options)
val_encodings = tokenizer(list(X_val), **_batch_tokenizer_options)
test_encodings = tokenizer(list(X_test), **_batch_tokenizer_options)

# Wrap encodings and raw target arrays for the Trainer.
train_dataset = FreeSolvDataset(train_encodings, y_train.values)
val_dataset = FreeSolvDataset(val_encodings, y_val.values)
test_dataset = FreeSolvDataset(test_encodings, y_test.values)

# Evaluate and checkpoint once per epoch; keep a single checkpoint and reload
# the one with the lowest eval_loss when training ends.
_training_config = {
    "output_dir": "./results",
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "learning_rate": 1e-5,
    "num_train_epochs": 100,
    "per_device_train_batch_size": 64,
    "per_device_eval_batch_size": 64,
    "weight_decay": 0.01,
    "logging_dir": "./logs",
    "logging_steps": 10,
    "save_total_limit": 1,
    "load_best_model_at_end": True,
    "metric_for_best_model": "eval_loss",
    "greater_is_better": False,
}
training_args = TrainingArguments(**_training_config)

# Regression metrics reported at every evaluation
def compute_metrics(eval_pred):
    """Score one evaluation pass with MSE and R².

    Args:
        eval_pred: Pair ``(predictions, labels)``. Predictions are squeezed
            before scoring.

    Returns:
        dict with the keys ``"mse"`` and ``"r2"``.
    """
    raw_predictions, targets = eval_pred
    flat_predictions = raw_predictions.squeeze()
    return {
        "mse": mean_squared_error(targets, flat_predictions),
        "r2": r2_score(targets, flat_predictions),
    }

# Create the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    # Early stopping with a patience of 3 evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

# Create the optimizer up front so the custom scheduler below can wrap it.
optimizer = trainer.create_optimizer()

# Linear decay without warmup, spread over the whole run.
# Steps per epoch are rounded UP: the final, partially filled batch is still an
# optimizer step (assuming the default dataloader_drop_last=False). The previous
# floor division ended the schedule early and left the last steps at lr == 0.
# NOTE(review): gradient accumulation and multi-GPU batch scaling are not
# accounted for here; neither is configured in this notebook.
_steps_per_epoch = -(-len(train_dataset) // training_args.per_device_train_batch_size)  # ceiling division
num_training_steps = _steps_per_epoch * training_args.num_train_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

# Hand the scheduler to the trainer before training starts.
trainer.lr_scheduler = scheduler

# Train the model
trainer.train()

# Evaluate the model on the test set
test_results = trainer.evaluate(test_dataset)
print(test_results)

# Predict on the test set
predictions = trainer.predict(test_dataset).predictions.squeeze()

# Plot predictions vs true values for test set; the dashed diagonal marks y = x.
plt.figure(figsize=(8, 6))
plt.scatter(y_test, predictions, alpha=0.6, color='green')
plt.xlabel('True Values')
plt.ylabel('Predictions')
plt.title('Test Set: True Values vs Predictions')
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2)
plt.show()


In [None]:
!pip install rdkit

In [None]:
#Experimental Class for Smiles Enumeration, Iterator and SmilesIterator adapted from Keras 1.2.2
from rdkit import Chem
import numpy as np
import threading

class Iterator(object):
    """Base class for iterators that hand out batches of sample indices.

    Subclasses are expected to provide ``next()``; ``__next__`` forwards to it.

    # Arguments
        n: Integer, total number of samples in the dataset to loop over.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.
    """

    def __init__(self, n, batch_size, shuffle, seed):
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.batch_index = 0
        self.total_batches_seen = 0
        # Subclasses hold this lock while advancing the index generator.
        self.lock = threading.Lock()
        # The generator is lazy: nothing in _flow_index runs until the first next().
        self.index_generator = self._flow_index(n, batch_size, shuffle, seed)
        if n < batch_size:
            raise ValueError('Input data length is shorter than batch_size\nAdjust batch_size')

    def reset(self):
        """Rewind to the first batch of an epoch."""
        self.batch_index = 0

    def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
        """Endlessly yield ``(index_array_slice, start, size)`` triples.

        The last batch of an epoch is shorter when ``n`` is not a multiple of
        ``batch_size``; after it, the index order is rebuilt.
        """
        # Start from a known position.
        self.reset()
        while True:
            if seed is not None:
                # Re-seeds the global NumPy RNG before every batch.
                np.random.seed(seed + self.total_batches_seen)
            if self.batch_index == 0:
                # New epoch: rebuild the index order.
                index_array = np.random.permutation(n) if shuffle else np.arange(n)

            start = (self.batch_index * batch_size) % n
            if start + batch_size < n:
                size = batch_size
                self.batch_index += 1
            else:
                # Final batch of the epoch: take what is left and wrap around.
                size = n - start
                self.batch_index = 0
            self.total_batches_seen += 1
            yield (index_array[start: start + size], start, size)

    def __iter__(self):
        # Allows: for x, y in data_gen.flow(...):
        return self

    def __next__(self, *args, **kwargs):
        return self.next(*args, **kwargs)




class SmilesIterator(Iterator):
    """Iterator yielding one-hot encoded batches from a SMILES array.

    # Arguments
        x: Array-like of SMILES input data (NumPy array, pandas Series or list).
        y: Array-like of targets, or None to yield inputs only.
        smiles_data_generator: Instance of `SmilesEnumerator`
            used to vectorize (and optionally randomize) each SMILES.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        dtype: dtype to use for returned batch. Set to keras.backend.floatx if using Keras

    # Raises
        ValueError: if `x` and `y` differ in length, or (from the base class)
            if there are fewer samples than `batch_size`.
    """

    def __init__(self, x, y, smiles_data_generator,
                 batch_size=32, shuffle=False, seed=None,
                 dtype=np.float32
                 ):
        if y is not None and len(x) != len(y):
            raise ValueError('X (images tensor) and y (labels) '
                             'should have the same length. '
                             'Found: X.shape = %s, y.shape = %s' %
                             (np.asarray(x).shape, np.asarray(y).shape))

        self.x = np.asarray(x)
        self.y = np.asarray(y) if y is not None else None
        self.smiles_data_generator = smiles_data_generator
        self.dtype = dtype
        # Take the length from the converted array: the raw argument may be a
        # plain list, which has no .shape attribute.
        super(SmilesIterator, self).__init__(self.x.shape[0], batch_size, shuffle, seed)

    def next(self):
        """Return the next batch (Python 2 style name; ``__next__`` forwards here).

        # Returns
            `batch_x` alone when no targets were given, otherwise the tuple
            `(batch_x, batch_y)`. `batch_x` has shape
            `(current_batch_size, pad, charlen)` of the data generator.
        """
        # Only the index bookkeeping is done under the lock; the encoding below
        # runs outside it.
        with self.lock:
            index_array, current_index, current_batch_size = next(self.index_generator)

        generator = self.smiles_data_generator
        batch_x = np.zeros((current_batch_size, generator.pad, generator._charlen), dtype=self.dtype)
        for i, j in enumerate(index_array):
            # Slice (not index) so transform() receives a length-1 array.
            batch_x[i] = generator.transform(self.x[j:j + 1])[0]

        if self.y is None:
            return batch_x
        return batch_x, self.y[index_array]


class SmilesEnumerator(object):
    """SMILES enumerator, vectorizer and devectorizer.

    #Arguments
        charset: string containing the characters for the vectorization;
          can also be generated via the .fit() method
        pad: Length of the vectorization (number of character positions)
        leftpad: If True, right-align each SMILES inside the padded window
        isomericSmiles: Passed to RDKit when a randomized SMILES is written
        enum: Randomize each SMILES during transform
        canonical: Passed to RDKit's MolToSmiles in randomize_smiles
    """
    def __init__(self, charset = '@C)(=cOn1S2/H[N]\\', pad=120, leftpad=True, isomericSmiles=True, enum=True, canonical=False):
        self._charset = None
        self.charset = charset  # property setter also builds the lookup tables
        self.pad = pad
        self.leftpad = leftpad
        self.isomericSmiles = isomericSmiles
        self.enumerate = enum
        self.canonical = canonical

    @property
    def charset(self):
        return self._charset

    @charset.setter
    def charset(self, charset):
        # Keep the string, its length and both lookup directions in sync.
        self._charset = charset
        self._charlen = len(charset)
        self._char_to_int = dict((c,i) for i,c in enumerate(charset))
        self._int_to_char = dict((i,c) for i,c in enumerate(charset))

    def fit(self, smiles, extra_chars=[], extra_pad = 5):
        """Extract the charset and maximum length from a SMILES dataset and set self.charset and self.pad.

        #Arguments
            smiles: Numpy array or Pandas series containing smiles as strings
            extra_chars: List of extra chars to add to the charset (e.g. "\\\\" when "/" is present)
            extra_pad: Extra padding to add before or after the SMILES vectorization
        """
        observed = set("".join(list(smiles)))
        # Sorted so the one-hot column order is the same on every run; set
        # iteration order for strings is not stable across interpreter sessions.
        self.charset = "".join(sorted(observed.union(set(extra_chars))))
        self.pad = max([len(smile) for smile in smiles]) + extra_pad

    def randomize_smiles(self, smiles):
        """Return a randomized SMILES string for the same molecule.

        The atoms are renumbered in a random order (global NumPy RNG) before
        the molecule is written back out.

        #Raises
            ValueError: if RDKit cannot parse `smiles`.
        """
        m = Chem.MolFromSmiles(smiles)
        if m is None:
            raise ValueError("RDKit could not parse SMILES: %r" % (smiles,))
        ans = list(range(m.GetNumAtoms()))
        np.random.shuffle(ans)
        nm = Chem.RenumberAtoms(m,ans)
        return Chem.MolToSmiles(nm, canonical=self.canonical, isomericSmiles=self.isomericSmiles)

    def transform(self, smiles):
        """One-hot encode SMILES strings, randomizing each first when enumeration is on.

        #Arguments
            smiles: Numpy array, Pandas series or list containing smiles as strings
        #Returns
            int8 array of shape (len(smiles), self.pad, len(self.charset))
        #Raises
            ValueError: if a (possibly randomized) SMILES is longer than self.pad.
            KeyError: if a character is not part of the charset.
        """
        one_hot = np.zeros((len(smiles), self.pad, self._charlen), dtype=np.int8)

        for i, ss in enumerate(smiles):
            if self.enumerate:
                ss = self.randomize_smiles(ss)
            length = len(ss)
            if length > self.pad:
                # With leftpad the offset would turn negative and the write
                # would silently wrap around to the end of the window.
                raise ValueError("SMILES of length %d does not fit into pad=%d: %r"
                                 % (length, self.pad, ss))
            offset = self.pad - length if self.leftpad else 0
            for j, c in enumerate(ss):
                one_hot[i, j + offset, self._char_to_int[c]] = 1
        return one_hot

    def reverse_transform(self, vect):
        """Convert vectorized SMILES back to strings.

        The charset must be the same as the one used for vectorization.

        #Arguments
            vect: Numpy array of vectorized SMILES.
        """
        smiles = []
        for v in vect:
            # Drop padding rows: keep only positions with exactly one hot entry.
            v = v[v.sum(axis=1) == 1]
            # argmax gives the charset index of each remaining position.
            smile = "".join(self._int_to_char[i] for i in v.argmax(axis=1))
            smiles.append(smile)
        return np.array(smiles)
     
# Smoke tests for SmilesEnumerator / SmilesIterator. Every check only prints a
# message when it fails; nothing is raised or asserted.
# NOTE(review): the checks draw from the global NumPy RNG in sequence, so the
# statement order matters for reproducibility.
if __name__ == "__main__":
    # Two SMILES strings repeated ten times each (20 entries).
    smiles = np.array([ "CCC(=O)O[C@@]1(CC[NH+](C[C@H]1CC=C)C)c2ccccc2",
                        "CCC[S@@](=O)c1ccc2c(c1)[nH]/c(=N/C(=O)OC)/[nH]2"]*10
                        )
    # Vectorization without enumeration: the round trip may yield at most the two input strings.
    sm_en = SmilesEnumerator(canonical=True, enum=False)
    sm_en.fit(smiles, extra_chars=["\\"])
    v = sm_en.transform(smiles)
    transformed = sm_en.reverse_transform(v)
    if len(set(transformed)) > 2: print("Too many different canonical SMILES generated")
    
    # Enumeration switched on: randomization should produce more than two distinct strings.
    sm_en.canonical = False
    sm_en.enumerate = True
    v2 = sm_en.transform(smiles)
    transformed = sm_en.reverse_transform(v2)
    if len(set(transformed)) < 3: print("Too few enumerated SMILES generated")

    # Reconstruction: decoding the non-enumerated encoding must give back the inputs.
    reconstructed = sm_en.reverse_transform(v[0:5])
    for i, smile in enumerate(reconstructed):
        if smile != smiles[i]:
            print("Error in reconstruction %s %s"%(smile, smiles[i]))
            break
    
    # Pandas input: transform() is fed a Series and only the output shape is checked.
    import pandas as pd
    df = pd.DataFrame(smiles)
    v = sm_en.transform(df[0])
    if v.shape != (20, 52, 18): print("Possible error in pandas use")
    
    # Known limitation noted by the author: when batch_size > x.shape[0] only
    # x.shape[0] samples would be returned (Iterator.__init__ raises in that case).
    # Batch generation: draw one shuffled batch of 10 and check size and label balance.
    sm_it = SmilesIterator(smiles, np.array([1,2]*10), sm_en, batch_size=10, shuffle=True)
    X, y = sm_it.next()
    # NOTE(review): the difference is not wrapped in abs(), so only an excess of
    # label 1 over label 2 is reported.
    if sum(y==1) - sum(y==2) > 1:
        print("Unbalanced generation of batches")
    if len(X) != 10: print("Error in batchsize generation")



        

In [None]:
# Shared enumerator instance; later cells call sme.randomize_smiles().
sme = SmilesEnumerator()
# help() prints the documentation itself and returns None, so wrapping it in
# print() only appended a stray "None" line.
help(SmilesEnumerator)

In [None]:
# Show ten randomized spellings of the same molecule.
example_smiles = "CCC(=O)O[C@@]1(CC[NH+](C[C@H]1CC=C)C)c2ccccc2"
for i in range(10):
    randomized = sme.randomize_smiles(example_smiles)
    print(randomized)

In [None]:
# FreeSolv table: "smiles" holds the inputs, "expt" the target values.
df = dm.data.freesolv()
X = df["smiles"]
y = df["expt"]


In [None]:
# Spot check: randomize one SMILES from the dataset and display the result.
# X[0] looks up index label 0 — presumably the first row of the freshly loaded frame.
sme.randomize_smiles(X[0])

In [None]:
# Build 100 randomized SMILES per molecule for the WHOLE dataset, repeating each
# target alongside. X[i] / y[i] are label lookups on the Series index.
_COPIES_PER_MOLECULE = 100
new_X = [
    sme.randomize_smiles(X[i])
    for i in range(len(X))
    for _ in range(_COPIES_PER_MOLECULE)
]
new_y = [
    y[i]
    for i in range(len(X))
    for _ in range(_COPIES_PER_MOLECULE)
]

In [None]:
import datamol as dm
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, EarlyStoppingCallback, get_linear_schedule_with_warmup
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt

# The tokenizer and the single-output (num_labels=1) model share one ChemBERTa checkpoint.
_CHECKPOINT = "seyonec/ChemBERTa-zinc-base-v1"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT, num_labels=1)

from copy import deepcopy

# Original (un-augmented) FreeSolv data: "smiles" inputs, "expt" targets.
df = dm.data.freesolv()
X = df["smiles"]
y = df["expt"]

# Disabled alternative: start from the pre-augmented new_X / new_y lists instead.
# X,y = deepcopy(new_X),deepcopy(new_y)
# X = pd.DataFrame(X).squeeze()
# y = pd.DataFrame(y).squeeze()


def tokenize_function(smiles):
    """Tokenize SMILES input with the module-level ``tokenizer``.

    The call pads to ``max_length`` and truncates, both at 128 tokens.
    """
    fixed_length_options = {"padding": "max_length", "truncation": True, "max_length": 128}
    return tokenizer(smiles, **fixed_length_options)



# Convert the tokenized data to the format required by Hugging Face
class FreeSolvDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with regression targets.

    Args:
        encodings: Mapping from feature name to a container indexable by
            sample position. Each row may be a tensor or a (nested) list.
        labels: Indexable sequence of numeric targets; its length defines
            the dataset length.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors plus a float32 ``labels`` entry."""
        item = {}
        for key, val in self.encodings.items():
            row = val[idx]
            if isinstance(row, torch.Tensor):
                # torch.tensor(<tensor>) emits a copy-construct UserWarning on
                # every access; clone().detach() is the supported way to copy.
                item[key] = row.clone().detach()
            else:
                item[key] = torch.tensor(row)
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

    def __len__(self):
        """Number of samples, taken from the label sequence."""
        return len(self.labels)

# 60/20/20 split: hold out 40 % of the data, then halve the hold-out into
# validation and test. Both calls use the same fixed seed.
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.4, random_state=42
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)

# Augment the training split only: 100 randomized SMILES per molecule, each
# paired with the molecule's target. Uses the `sme` instance from an earlier cell.
listx = X_train.tolist()
listy = y_train.tolist()
new_X = [sme.randomize_smiles(smiles_string) for smiles_string in listx for _ in range(100)]
new_y = [target for target in listy for _ in range(100)]

# Back to pandas objects so the code below can keep using .values.
X_train = pd.DataFrame(new_X).squeeze()
y_train = pd.DataFrame(new_y).squeeze()

# Batch-tokenize each split; padding=True pads within each call.
_batch_tokenizer_options = dict(padding=True, truncation=True, max_length=128, return_tensors="pt")
train_encodings = tokenizer(list(X_train), **_batch_tokenizer_options)
val_encodings = tokenizer(list(X_val), **_batch_tokenizer_options)
test_encodings = tokenizer(list(X_test), **_batch_tokenizer_options)

# Wrap encodings and raw target arrays for the Trainer.
train_dataset = FreeSolvDataset(train_encodings, y_train.values)
val_dataset = FreeSolvDataset(val_encodings, y_val.values)
test_dataset = FreeSolvDataset(test_encodings, y_test.values)

# Evaluate and checkpoint once per epoch; keep a single checkpoint and reload
# the one with the lowest eval_loss when training ends.
_training_config = {
    "output_dir": "./results",
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "learning_rate": 1e-5,
    "num_train_epochs": 100,
    "per_device_train_batch_size": 64,
    "per_device_eval_batch_size": 64,
    "weight_decay": 0.01,
    "logging_dir": "./logs",
    "logging_steps": 10,
    "save_total_limit": 1,
    "load_best_model_at_end": True,
    "metric_for_best_model": "eval_loss",
    "greater_is_better": False,
}
training_args = TrainingArguments(**_training_config)

# Regression metrics reported at every evaluation
def compute_metrics(eval_pred):
    """Score one evaluation pass with MSE and R².

    Args:
        eval_pred: Pair ``(predictions, labels)``. Predictions are squeezed
            before scoring.

    Returns:
        dict with the keys ``"mse"`` and ``"r2"``.
    """
    raw_predictions, targets = eval_pred
    flat_predictions = raw_predictions.squeeze()
    return {
        "mse": mean_squared_error(targets, flat_predictions),
        "r2": r2_score(targets, flat_predictions),
    }

# Create the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    # Early stopping with a patience of 3 evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

# Create the optimizer up front so the custom scheduler below can wrap it.
optimizer = trainer.create_optimizer()

# Linear decay without warmup, spread over the whole run.
# Steps per epoch are rounded UP: the final, partially filled batch is still an
# optimizer step (assuming the default dataloader_drop_last=False). The previous
# floor division ended the schedule early and left the last steps at lr == 0.
# NOTE(review): gradient accumulation and multi-GPU batch scaling are not
# accounted for here; neither is configured in this notebook.
_steps_per_epoch = -(-len(train_dataset) // training_args.per_device_train_batch_size)  # ceiling division
num_training_steps = _steps_per_epoch * training_args.num_train_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

# Hand the scheduler to the trainer before training starts.
trainer.lr_scheduler = scheduler

# Train the model
trainer.train()

# Evaluate the model on the test set
test_results = trainer.evaluate(test_dataset)
print(test_results)

# Predict on the test set
predictions = trainer.predict(test_dataset).predictions.squeeze()

# Plot predictions vs true values for test set; the dashed diagonal marks y = x.
plt.figure(figsize=(8, 6))
plt.scatter(y_test, predictions, alpha=0.6, color='green')
plt.xlabel('True Values')
plt.ylabel('Predictions')
plt.title('Test Set: True Values vs Predictions')
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2)
plt.show()


Check whether anything went wrong in the run above, then add stochastic weight averaging (SWA).

In [None]:
import datamol as dm
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, EarlyStoppingCallback, get_linear_schedule_with_warmup
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
from copy import deepcopy

# The tokenizer and the single-output (num_labels=1) model share one ChemBERTa checkpoint.
_CHECKPOINT = "seyonec/ChemBERTa-zinc-base-v1"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT, num_labels=1)

# FreeSolv table: "smiles" inputs, "expt" targets.
df = dm.data.freesolv()
X = df["smiles"]
y = df["expt"]


def tokenize_function(smiles):
    """Tokenize SMILES input with the module-level ``tokenizer``.

    The call pads to ``max_length`` and truncates, both at 128 tokens, so every
    encoding produced here has the same length.
    """
    fixed_length_options = {"padding": "max_length", "truncation": True, "max_length": 128}
    return tokenizer(smiles, **fixed_length_options)

# Convert the tokenized data to the format required by Hugging Face
class FreeSolvDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with regression targets.

    Args:
        encodings: Mapping from feature name to a container indexable by
            sample position. Each row may be a tensor or a (nested) list.
        labels: Indexable sequence of numeric targets; its length defines
            the dataset length.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors plus a float32 ``labels`` entry."""
        item = {}
        for key, val in self.encodings.items():
            row = val[idx]
            if isinstance(row, torch.Tensor):
                # torch.tensor(<tensor>) emits a copy-construct UserWarning on
                # every access; clone().detach() is the supported way to copy.
                item[key] = row.clone().detach()
            else:
                item[key] = torch.tensor(row)
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

    def __len__(self):
        """Number of samples, taken from the label sequence."""
        return len(self.labels)

# Split the data into training, validation, and test sets (60/20/20, fixed seed)
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

# Data augmentation: 100 randomized SMILES per training molecule, each paired
# with the molecule's target.
# The earlier call to dm.randomize_smiles was only an assumption; datamol's
# documented route is to_smiles(..., randomize=True) on a parsed molecule.
new_X, new_y = [], []
listx = X_train.tolist()
listy = y_train.tolist()

for smiles_string, target in zip(listx, listy):
    mol = dm.to_mol(smiles_string)  # parse once, write out many times
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles_string!r}")
    for j in range(100):
        new_X.append(dm.to_smiles(mol, randomize=True))
        new_y.append(target)

X_train = pd.DataFrame(new_X).squeeze()
y_train = pd.DataFrame(new_y).squeeze()

# Apply the tokenization and assign the results back to the variables.
# From here on X_train / X_val / X_test hold one encoding per molecule, not strings.
X_train = X_train.apply(tokenize_function)
X_val = X_val.apply(tokenize_function)
X_test = X_test.apply(tokenize_function)

# Stack per-sample encodings into batch tensors
def convert_to_tensor(tokenized_data):
    """Stack the ``input_ids`` and ``attention_mask`` of every sample.

    Args:
        tokenized_data: Iterable of mappings that each provide ``input_ids``
            and ``attention_mask``. It is iterated once per field, and all
            samples must share one length for the stacking to succeed.

    Returns:
        dict with one tensor per field, samples along the first dimension.
    """
    stacked = {}
    for field in ('input_ids', 'attention_mask'):
        stacked[field] = torch.tensor([sample[field] for sample in tokenized_data])
    return stacked

# Stack the per-molecule encodings of each split into batch tensors.
train_encodings, val_encodings, test_encodings = (
    convert_to_tensor(X_train),
    convert_to_tensor(X_val),
    convert_to_tensor(X_test),
)

# Wrap encodings and raw target arrays for the Trainer.
train_dataset = FreeSolvDataset(train_encodings, y_train.values)
val_dataset = FreeSolvDataset(val_encodings, y_val.values)
test_dataset = FreeSolvDataset(test_encodings, y_test.values)

# Evaluate and checkpoint once per epoch; keep a single checkpoint and reload
# the one with the lowest eval_loss when training ends.
_training_config = {
    "output_dir": "./results",
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "learning_rate": 1e-5,
    "num_train_epochs": 100,
    "per_device_train_batch_size": 64,
    "per_device_eval_batch_size": 64,
    "weight_decay": 0.01,
    "logging_dir": "./logs",
    "logging_steps": 10,
    "save_total_limit": 1,
    "load_best_model_at_end": True,
    "metric_for_best_model": "eval_loss",
    "greater_is_better": False,
}
training_args = TrainingArguments(**_training_config)

# Regression metrics reported at every evaluation
def compute_metrics(eval_pred):
    """Score one evaluation pass with MSE and R².

    Args:
        eval_pred: Pair ``(predictions, labels)``. Predictions are squeezed
            before scoring.

    Returns:
        dict with the keys ``"mse"`` and ``"r2"``.
    """
    raw_predictions, targets = eval_pred
    flat_predictions = raw_predictions.squeeze()
    return {
        "mse": mean_squared_error(targets, flat_predictions),
        "r2": r2_score(targets, flat_predictions),
    }

# Create the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    # Early stopping with a patience of 3 evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

# Create the optimizer and scheduler.
# Steps per epoch are rounded UP: the final, partially filled batch is still an
# optimizer step (assuming the default dataloader_drop_last=False). Floor
# division under-counted the total and ended the schedule early.
_steps_per_epoch = -(-len(train_dataset) // training_args.per_device_train_batch_size)  # ceiling division
trainer.create_optimizer_and_scheduler(
    num_training_steps=_steps_per_epoch * training_args.num_train_epochs
)

# Train the model
trainer.train()

# Evaluate the model on the test set
test_results = trainer.evaluate(test_dataset)
print(test_results)

# Predict on the test set
predictions = trainer.predict(test_dataset).predictions.squeeze()

# Plot predictions vs true values for test set; the dashed diagonal marks y = x.
plt.figure(figsize=(8, 6))
plt.scatter(y_test, predictions, alpha=0.6, color='green')
plt.xlabel('True Values')
plt.ylabel('Predictions')
plt.title('Test Set: True Values vs Predictions')
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2)
plt.show()


In [None]:
import datamol as dm
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, EarlyStoppingCallback
import torch
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt

# The tokenizer and the single-output (num_labels=1) model share one ChemBERTa checkpoint.
_CHECKPOINT = "seyonec/ChemBERTa-zinc-base-v1"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT, num_labels=1)

# FreeSolv table: "smiles" inputs, "expt" targets.
df = dm.data.freesolv()
X = df["smiles"]
y = df["expt"]

# 60/20/20 split: hold out 40 %, then halve the hold-out (same fixed seed twice).
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.4, random_state=42
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)

# Augment the training split only: 100 randomized SMILES per molecule, each
# paired with the molecule's target. Uses the `sme` instance from an earlier cell.
listx = X_train.tolist()
listy = y_train.tolist()
new_X = [sme.randomize_smiles(smiles_string) for smiles_string in listx for _ in range(100)]
new_y = [target for target in listy for _ in range(100)]

# Back to pandas objects so the code below can keep using .values.
X_train = pd.DataFrame(new_X).squeeze()
y_train = pd.DataFrame(new_y).squeeze()

# Batch-tokenize each split; padding=True pads within each call.
_batch_tokenizer_options = dict(padding=True, truncation=True, max_length=128, return_tensors="pt")
train_encodings = tokenizer(list(X_train), **_batch_tokenizer_options)
val_encodings = tokenizer(list(X_val), **_batch_tokenizer_options)
test_encodings = tokenizer(list(X_test), **_batch_tokenizer_options)

# Dataset wrapper around tensor encodings
class FreeSolvDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tensor encodings with regression targets.

    Args:
        encodings: Mapping from feature name to a tensor whose first
            dimension indexes the samples.
        labels: Indexable sequence of numeric targets; its length defines
            the dataset length.
    """

    def __init__(self, encodings, labels):
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return detached copies of every feature row plus a float32 ``labels`` tensor."""
        sample = {}
        for name, column in self.encodings.items():
            sample[name] = column[idx].clone().detach()
        target = torch.tensor(self.labels[idx], dtype=torch.float)
        sample["labels"] = target.clone().detach()
        return sample

# Wrap encodings and raw target arrays for the Trainer.
train_dataset, val_dataset, test_dataset = (
    FreeSolvDataset(train_encodings, y_train.values),
    FreeSolvDataset(val_encodings, y_val.values),
    FreeSolvDataset(test_encodings, y_test.values),
)

# Evaluate and checkpoint once per epoch; keep a single checkpoint and reload
# the one with the lowest eval_loss when training ends.
_training_config = {
    "output_dir": "./results",
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "learning_rate": 3e-5,
    "num_train_epochs": 1000,
    "per_device_train_batch_size": 128,
    "per_device_eval_batch_size": 128,
    "weight_decay": 0.01,
    "logging_dir": "./logs",
    "logging_steps": 10,
    "save_total_limit": 1,
    "load_best_model_at_end": True,
    "metric_for_best_model": "eval_loss",
    "greater_is_better": False,
}
training_args = TrainingArguments(**_training_config)

# Regression metrics reported at every evaluation
def compute_metrics(eval_pred):
    """Score one evaluation pass with MSE and R².

    Args:
        eval_pred: Pair ``(predictions, labels)``. Predictions are squeezed
            before scoring.

    Returns:
        dict with the keys ``"mse"`` and ``"r2"``.
    """
    raw_predictions, targets = eval_pred
    flat_predictions = raw_predictions.squeeze()
    return {
        "mse": mean_squared_error(targets, flat_predictions),
        "r2": r2_score(targets, flat_predictions),
    }

# Trainer with early stopping (patience of 3 evaluations).
_early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[_early_stopping],
)

# Build the optimizer now so a custom scheduler can be attached to it below.
optimizer = trainer.create_optimizer()

# Custom learning rate scheduler
class CustomLRScheduler(torch.optim.lr_scheduler.LambdaLR):
    """Halve the learning rate linearly over ``decay_epochs``, then hold it.

    ``LambdaLR`` passes the number of ``step()`` calls to ``lr_lambda``. When
    the scheduler is stepped once per batch rather than once per epoch, pass
    ``steps_per_epoch`` so the decay is still measured in epochs.

    Args:
        optimizer: Wrapped optimizer.
        total_epochs: Planned number of epochs (stored, not used by the schedule).
        decay_epochs: Number of epochs over which the multiplier goes 1.0 -> 0.5.
        steps_per_epoch: ``step()`` calls per epoch. The default of 1 keeps the
            original behaviour (one ``step()`` == one epoch).

    Raises:
        ValueError: if ``steps_per_epoch`` is smaller than 1.
    """

    def __init__(self, optimizer, total_epochs, decay_epochs, steps_per_epoch=1):
        if steps_per_epoch < 1:
            raise ValueError("steps_per_epoch must be >= 1")
        # Set before super().__init__: LambdaLR evaluates lr_lambda during construction.
        self.total_epochs = total_epochs
        self.decay_epochs = decay_epochs
        self.steps_per_epoch = steps_per_epoch
        super().__init__(optimizer, self.lr_lambda)

    def lr_lambda(self, epoch):
        """Return the LR multiplier after ``epoch`` calls to ``step()``."""
        elapsed_epochs = epoch / self.steps_per_epoch
        if elapsed_epochs < self.decay_epochs:
            return 1 - (elapsed_epochs / self.decay_epochs) * 0.5  # Decay by a factor of 2
        else:
            return 0.5  # Remain flat

from transformers import TrainerCallback

# Initialize the custom scheduler.
# The Trainer advances the scheduler once per optimizer step, not once per
# epoch, so the 10-epoch decay is expressed in steps here. Steps per epoch are
# rounded up because the last partial batch is still a step (assuming the
# default dataloader_drop_last=False).
_steps_per_epoch = -(-len(train_dataset) // training_args.per_device_train_batch_size)
scheduler = CustomLRScheduler(
    optimizer,
    total_epochs=training_args.num_train_epochs,
    decay_epochs=10 * _steps_per_epoch,
)

# Assign the scheduler to the trainer
trainer.lr_scheduler = scheduler


class _SWACallback(TrainerCallback):
    """Fold the live weights into an AveragedModel at the end of each epoch past ``start_epoch``."""

    def __init__(self, averaged_model, start_epoch):
        self.averaged_model = averaged_model
        self.start_epoch = start_epoch

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        # state.epoch is the number of completed epochs at this point; the
        # original loop used a 0-based index with `epoch >= swa_start`.
        if model is None or state.epoch is None:
            return
        if round(state.epoch) > self.start_epoch:
            self.averaged_model.update_parameters(model)


# Apply SWA after epoch 10.
# The previous version called trainer.train() once and then again inside a
# 1000-iteration loop, re-running the complete training every iteration. A
# single train() call with an epoch-end callback does the averaging instead.
# No SWALR is created: it would compete with the Trainer-driven scheduler on
# the same optimizer, and the schedule above is already flat after the decay.
swa_model = AveragedModel(model)
swa_start = 10
trainer.add_callback(_SWACallback(swa_model, swa_start))

# Train the model
trainer.train()

# update_bn() is not called: it only refreshes BatchNorm statistics and feeds
# each batch positionally, which does not fit the dict batches used here.
# NOTE(review): presumably this model has no BatchNorm layers — confirm.
if int(swa_model.n_averaged) > 0:
    # Evaluate the averaged weights: copy them into the model the trainer holds.
    model.load_state_dict(swa_model.module.state_dict())
else:
    # Early stopping (patience 3) can end training before SWA ever starts.
    print("SWA never started (training stopped early); evaluating the best checkpoint instead.")

# Evaluate the model on the test set
test_results = trainer.evaluate(test_dataset)
print(test_results)

# Predict on the test set
predictions = trainer.predict(test_dataset).predictions.squeeze()

# Plot predictions vs true values for test set; the dashed diagonal marks y = x.
plt.figure(figsize=(8, 6))
plt.scatter(y_test, predictions, alpha=0.6, color='green')
plt.xlabel('True Values')
plt.ylabel('Predictions')
plt.title('Test Set: True Values vs Predictions')
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2)
plt.show()


In [None]:
import datamol as dm
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch.optim import AdamW
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
from tqdm import tqdm

# Load the ChemBERTa model and tokenizer (single-output head via num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model = AutoModelForSequenceClassification.from_pretrained("seyonec/ChemBERTa-zinc-base-v1", num_labels=1)

# Load the data: "smiles" inputs, "expt" targets
print("Loading data...")
df = dm.data.freesolv()
X, y = df["smiles"], df["expt"]

# Split the data into training, validation, and test sets (60/20/20, fixed seed)
print("Splitting data into train, validation, and test sets...")
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

# Augment the training data with randomized SMILES strings
print("Augmenting training data...")
new_X, new_y = [], []
listx = X_train.tolist()
listy = y_train.tolist()
for i in range(len(X_train)):
    for j in range(100):
        # Randomize each copy. Previously the unchanged SMILES was appended 100
        # times, which only duplicated the training set. Uses the `sme` instance
        # from an earlier cell, like the other augmentation cells.
        new_X.append(sme.randomize_smiles(listx[i]))
        new_y.append(listy[i])

# Convert augmented data to DataFrame and Series
X_train = pd.DataFrame(new_X).squeeze()
y_train = pd.DataFrame(new_y).squeeze()

# Tokenize the datasets; padding=True pads within each call
print("Tokenizing datasets...")
train_encodings = tokenizer(list(X_train), padding=True, truncation=True, max_length=128, return_tensors="pt")
val_encodings = tokenizer(list(X_val), padding=True, truncation=True, max_length=128, return_tensors="pt")
test_encodings = tokenizer(list(X_test), padding=True, truncation=True, max_length=128, return_tensors="pt")

# Dataset wrapper around tensor encodings
class FreeSolvDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tensor encodings with regression targets.

    Args:
        encodings: Mapping from feature name to a tensor whose first
            dimension indexes the samples.
        labels: Indexable sequence of numeric targets; its length defines
            the dataset length.
    """

    def __init__(self, encodings, labels):
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return detached copies of every feature row plus a float32 ``labels`` tensor."""
        sample = {}
        for name, column in self.encodings.items():
            sample[name] = column[idx].clone().detach()
        target = torch.tensor(self.labels[idx], dtype=torch.float)
        sample["labels"] = target.clone().detach()
        return sample

print("Creating datasets...")
train_dataset, val_dataset, test_dataset = (
    FreeSolvDataset(train_encodings, y_train.values),
    FreeSolvDataset(val_encodings, y_val.values),
    FreeSolvDataset(test_encodings, y_test.values),
)

# Batch size shared by all three loaders; only the training loader shuffles.
print("Creating DataLoader for batching...")
_BATCH_SIZE = 128
train_loader = DataLoader(train_dataset, batch_size=_BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=_BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=_BATCH_SIZE)

# AdamW over all model parameters.
print("Creating optimizer and learning rate scheduler...")
optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)

# Custom learning rate scheduler function
def lr_lambda(epoch):
    """Return the learning-rate multiplier for the given schedule index.

    The factor falls linearly from 1.0 at index 0 towards 0.5 at index 10
    and is held at 0.5 from there on.
    """
    return 1 - (epoch / 10) * 0.5 if epoch < 10 else 0.5

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# Stochastic weight averaging: swa_model holds the averaged copy of `model`,
# swa_start is compared against the epoch index inside the training loop.
print("Initializing SWA...")
swa_start = 10
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=1e-5)

# Early-stopping bookkeeping
patience = 5
epochs_no_improve = 0
best_val_loss = best_swa_eval_loss = float("inf")
best_model_state = None

# Training loop
# Each epoch: one optimisation pass over train_loader, a validation pass for
# `model`, and -- once epoch >= swa_start -- an update plus validation pass for
# the averaged `swa_model`. Training stops after `patience` consecutive epochs
# in which neither validation loss improved.
print("Starting training loop...")
for epoch in range(1000):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}"):
        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch + 1} - Training Loss: {total_loss / len(train_loader)}")

    model.eval()
    eval_loss = 0
    for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}"):
        with torch.no_grad():
            outputs = model(**batch)
            loss = outputs.loss
            eval_loss += loss.item()

    eval_loss /= len(val_loader)
    print(f"Epoch {epoch + 1} - Validation Loss: {eval_loss}")

    swa_eval_loss = None
    if epoch >= swa_start:
        # SWA phase: fold the current weights into the running average and let
        # SWALR alone drive the optimizer's learning rate.
        swa_model.update_parameters(model)
        swa_scheduler.step()
        # NOTE(review): update_bn only recomputes BatchNorm statistics;
        # presumably a no-op for this model -- confirm.
        update_bn(train_loader, swa_model)
        swa_eval_loss = 0
        for batch in tqdm(val_loader, desc=f"SWA Validation Epoch {epoch + 1}"):
            with torch.no_grad():
                outputs = swa_model(**batch)
                loss = outputs.loss
                swa_eval_loss += loss.item()

        swa_eval_loss /= len(val_loader)
        print(f"SWA Model Evaluation after epoch {epoch + 1} - Validation Loss: {swa_eval_loss}")
    else:
        # lr_lambda is written in terms of epochs, so the LambdaLR scheduler is
        # advanced once per epoch -- and not at all once SWALR has taken over,
        # otherwise it would overwrite the SWA learning rate.
        scheduler.step()

    # Check for early stopping
    if eval_loss < best_val_loss or (swa_eval_loss is not None and swa_eval_loss < best_swa_eval_loss):
        if eval_loss < best_val_loss:
            best_val_loss = eval_loss
            # state_dict() returns references to the live parameters; clone
            # them so later optimizer steps cannot alter the saved snapshot.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        if swa_eval_loss is not None and swa_eval_loss < best_swa_eval_loss:
            best_swa_eval_loss = swa_eval_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print("Early stopping due to no improvement in validation loss or SWA validation loss.")
        break

# Load the best model state
# (best_model_state is still None if no epoch ever improved the validation
# loss, in which case the current weights are kept.)
print("Loading the best model state...")
if best_model_state is not None:
    model.load_state_dict(best_model_state)

# Final SWA model evaluation
# best_swa_eval_loss only leaves its initial +inf once the averaged model has
# been updated and validated. If early stopping fired before swa_start,
# swa_model still holds the weights it was constructed with, so the restored
# `model` is evaluated instead.
print("Evaluating final SWA model...")
if best_swa_eval_loss < float("inf"):
    update_bn(train_loader, swa_model)
    model = swa_model
else:
    print("SWA was never updated; evaluating the best non-averaged model instead.")
model.eval()

# Weight every batch loss by its sample count so that a short final batch does
# not distort the average (assumes outputs.loss is the per-batch mean).
test_loss = 0.0
num_test_samples = 0
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        outputs = model(**batch)
        samples_in_batch = batch["labels"].size(0)
        test_loss += outputs.loss.item() * samples_in_batch
        num_test_samples += samples_in_batch

test_loss /= num_test_samples
print(f"Final Test Loss: {test_loss}")

# Predictions on the test set
print("Generating predictions on the test set...")
model.eval()
predictions = []
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Predicting"):
        outputs = model(**batch)
        # Flatten explicitly: a bare squeeze() turns a single-sample batch into
        # a 0-d tensor whose tolist() is a float, which extend() cannot take.
        preds = outputs.logits.reshape(-1).tolist()
        predictions.extend(preds)

# Scatter the test-set predictions against the true values, with the y = x
# diagonal drawn as the perfect-prediction reference.
print("Plotting predictions vs true values...")
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(y_test, predictions, alpha=0.6, color='green')
ax.set_xlabel('True Values')
ax.set_ylabel('Predictions')
ax.set_title('Test Set: True Values vs Predictions')
y_low, y_high = y_test.min(), y_test.max()
ax.plot([y_low, y_high], [y_low, y_high], 'k--', lw=2)
plt.show()


In [None]:
import datamol as dm
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch.optim import AdamW
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
from tqdm import tqdm

# Load the ChemBERTa tokenizer and a single-output sequence-classification head
checkpoint_name = "seyonec/ChemBERTa-zinc-base-v1"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_name, num_labels=1)

# Use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load the data
print("Loading data...")
df = dm.data.freesolv()
X = df["smiles"]
y = df["expt"]

# Hold out 40% of the rows, then halve that hold-out into validation and test
print("Splitting data into train, validation, and test sets...")
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.4, random_state=42
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)

# Augment the training data: 100 randomized SMILES per training molecule,
# each paired with the molecule's original target.
# NOTE(review): `sme` is not imported in this cell; it has to be defined by an
# earlier cell for this to run.
print("Augmenting training data...")
listx = X_train.tolist()
listy = y_train.tolist()
new_X = [sme.randomize_smiles(smiles) for smiles in listx for _ in range(100)]
new_y = [target for target in listy for _ in range(100)]

# Back to pandas: build a one-column DataFrame and squeeze() it
X_train = pd.DataFrame(new_X).squeeze()
y_train = pd.DataFrame(new_y).squeeze()

# Tokenize each split in its own call, truncating at 128 tokens
print("Tokenizing datasets...")
tokenizer_options = dict(padding=True, truncation=True, max_length=128, return_tensors="pt")
train_encodings = tokenizer(list(X_train), **tokenizer_options)
val_encodings = tokenizer(list(X_val), **tokenizer_options)
test_encodings = tokenizer(list(X_test), **tokenizer_options)

# Create the datasets
class FreeSolvDataset(torch.utils.data.Dataset):
    """Dataset that serves one tokenized sample plus its target per index.

    `encodings` maps field names to indexable tensors; `labels` is an
    indexable sequence of targets with one entry per sample.
    """

    def __init__(self, encodings, labels):
        self.encodings, self.labels = encodings, labels

    def __len__(self):
        # The number of labels defines the dataset size
        return len(self.labels)

    def __getitem__(self, idx):
        # Copy row `idx` out of every encoding field
        row = {
            field: values[idx].detach().clone()
            for field, values in self.encodings.items()
        }
        # Attach the target as a float32 tensor
        row["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return row

# Wrap each split's encodings and targets in a FreeSolvDataset
print("Creating datasets...")
train_dataset = FreeSolvDataset(encodings=train_encodings, labels=y_train.values)
val_dataset = FreeSolvDataset(encodings=val_encodings, labels=y_val.values)
test_dataset = FreeSolvDataset(encodings=test_encodings, labels=y_test.values)

# Batch the datasets; only the training loader shuffles
print("Creating DataLoader for batching...")
train_loader = DataLoader(dataset=train_dataset, batch_size=1024, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=1024)
test_loader = DataLoader(dataset=test_dataset, batch_size=1024)

# Optimizer over all model parameters (the LR schedule is attached below)
print("Creating optimizer and learning rate scheduler...")
optimizer = AdamW(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=0.03,
)

# Custom learning rate scheduler function
def lr_lambda(epoch):
    """Learning-rate multiplier: linear ramp from 1.0 down to 0.5, then flat.

    For indices below 10 the factor is 1 - 0.05 * index (computed as
    1 - (index / 10) * 0.5); from index 10 onwards it is the constant 0.5.
    """
    in_decay_phase = epoch < 10
    if not in_decay_phase:
        return 0.5  # flat tail
    return 1 - (epoch / 10) * 0.5

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# Stochastic weight averaging: swa_model holds the averaged copy of `model`,
# swa_start is compared against the epoch index inside the training loop.
print("Initializing SWA...")
swa_start = 10
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=1e-5)

# Early-stopping / best-snapshot bookkeeping
patience = 5
epochs_no_improve = 0
best_val_loss = best_swa_eval_loss = float("inf")
best_model_state = best_swa_model_state = None

# Training loop
# Per epoch: train on train_loader, validate `model`, and -- once
# epoch >= swa_start -- update and validate the averaged `swa_model`.
# Early stopping only counts epochs after SWA has started.
print("Starting training loop...")
for epoch in range(1000):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}"):
        optimizer.zero_grad()
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch + 1} - Training Loss: {total_loss / len(train_loader)}")

    model.eval()
    eval_loss = 0
    for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}"):
        with torch.no_grad():
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss
            eval_loss += loss.item()

    eval_loss /= len(val_loader)
    print(f"Epoch {epoch + 1} - Validation Loss: {eval_loss}")

    # Check for best model state based on validation loss. Done before the SWA
    # section so the epoch that triggers early stopping is still considered.
    if eval_loss < best_val_loss:
        best_val_loss = eval_loss
        # state_dict() returns references to the live parameters; clone them so
        # later optimizer steps cannot alter the saved snapshot.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

    swa_eval_loss = None
    if epoch >= swa_start:
        # SWA phase: fold the current weights into the running average and let
        # SWALR alone drive the optimizer's learning rate.
        swa_model.update_parameters(model)
        swa_scheduler.step()
        # NOTE(review): update_bn only recomputes BatchNorm statistics;
        # presumably a no-op for this model -- confirm.
        update_bn(train_loader, swa_model)
        swa_eval_loss = 0
        for batch in tqdm(val_loader, desc=f"SWA Validation Epoch {epoch + 1}"):
            with torch.no_grad():
                inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
                labels = batch['labels'].to(device)
                outputs = swa_model(**inputs, labels=labels)
                loss = outputs.loss
                swa_eval_loss += loss.item()

        swa_eval_loss /= len(val_loader)
        print(f"SWA Model Evaluation after epoch {epoch + 1} - Validation Loss: {swa_eval_loss}")

        # Check for early stopping and best model state after SWA has started
        if swa_eval_loss < best_swa_eval_loss:
            best_swa_eval_loss = swa_eval_loss
            # Same aliasing concern as above: keep an independent copy.
            best_swa_model_state = {
                name: tensor.detach().clone()
                for name, tensor in swa_model.state_dict().items()
            }
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print("Early stopping due to no improvement in SWA validation loss.")
            break
    else:
        # lr_lambda is written in terms of epochs, so the LambdaLR scheduler is
        # advanced once per epoch -- and not at all once SWALR has taken over,
        # otherwise it would overwrite the SWA learning rate.
        scheduler.step()

# Load the best model state
# Evaluate the best SWA snapshot if it beat the best non-averaged validation
# loss, otherwise the best non-averaged snapshot. The selection is kept in
# `model`, which the checkpoint and prediction cells read afterwards.
if best_swa_model_state is not None and best_swa_eval_loss < best_val_loss:
    print("Loading the best SWA model state...")
    swa_model.load_state_dict(best_swa_model_state)
    update_bn(train_loader, swa_model)
    model = swa_model
else:
    print("Loading the best model state...")
    # best_model_state is still None if no epoch ever improved the validation
    # loss; keep the current weights in that case.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

# Final evaluation of the selected model on the test set
print("Evaluating final model...")
model.eval()

# Weight every batch loss by its sample count so that a short final batch does
# not distort the average (assumes outputs.loss is the per-batch mean).
test_loss = 0.0
num_test_samples = 0
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        outputs = model(**inputs, labels=labels)
        samples_in_batch = labels.size(0)
        test_loss += outputs.loss.item() * samples_in_batch
        num_test_samples += samples_in_batch

test_loss /= num_test_samples
print(f"Final Test Loss: {test_loss}")

# Predictions on the test set
print("Generating predictions on the test set...")
model.eval()
predictions = []
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Predicting"):
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        outputs = model(**inputs)
        # Flatten explicitly: a bare squeeze() turns a single-sample batch into
        # a 0-d tensor whose tolist() is a float, which extend() cannot take.
        preds = outputs.logits.reshape(-1).tolist()
        predictions.extend(preds)

# Scatter the test-set predictions against the true values, with the y = x
# diagonal drawn as the perfect-prediction reference.
print("Plotting predictions vs true values...")
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(y_test, predictions, alpha=0.6, color='green')
ax.set_xlabel('True Values')
ax.set_ylabel('Predictions')
ax.set_title('Test Set: True Values vs Predictions')
y_low, y_high = y_test.min(), y_test.max()
ax.plot([y_low, y_high], [y_low, y_high], 'k--', lw=2)
plt.show()


In [None]:
# Persist the current `model` weights; the file name embeds the test loss
# formatted with two decimals.
checkpoint_path = f"model_checkpoint_{test_loss:.2f}.pth"
torch.save(model.state_dict(), checkpoint_path)
print(f"Model checkpoint saved to {checkpoint_path}")


In [None]:
# Read the weights stored at checkpoint_path back into `model` and switch it
# to evaluation mode
saved_state = torch.load(checkpoint_path)
model.load_state_dict(saved_state)
model.eval()

# Define a function to tokenize input and get prediction
def predict(smiles_string):
    """Return the model's scalar output for one SMILES string.

    Uses the notebook-level `tokenizer` and `model`. The model is moved to
    CUDA when available (CPU otherwise) on every call, the tokenized input is
    moved to the same device, and the forward pass runs without gradients.

    Args:
        smiles_string: SMILES text handed to the tokenizer (truncated at 128
            tokens).

    Returns:
        The single logit as a Python float.
    """
    # Pick the device and make sure the model lives on it
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(target_device)

    # Tokenize and move every tensor next to the model
    encoded = tokenizer(smiles_string, return_tensors="pt", padding=True, truncation=True, max_length=128)
    model_inputs = {name: tensor.to(target_device) for name, tensor in encoded.items()}

    # Inference only: no gradient tracking
    with torch.no_grad():
        logits = model(**model_inputs).logits

    return logits.squeeze().item()

# Example usage: run predict() on a single SMILES string and print the result
smiles_string = "CCO"  # Replace this with your SMILES string
prediction = predict(smiles_string)
print(f"Prediction for {smiles_string}: {prediction}")

In [None]:
# Display the targets (`y` is assigned from df["expt"] in the training cells)
y

In [None]:
# Character length of the longest SMILES string in the dataset
max(len(smiles) for smiles in df["smiles"].tolist())

In [None]:
# (length, SMILES) pairs ordered by length, longest first
check_list = sorted(
    ((len(smiles), smiles) for smiles in df["smiles"].tolist()),
    key=lambda pair: pair[0],
    reverse=True,
)

In [None]:
# Tokenize the first check_list entry (the longest SMILES, since the list is
# sorted by length descending) with max_length=128
tokenizer(check_list[0][1], return_tensors="pt", padding=True, truncation=True, max_length=128)

In [None]:
# Same tokenizer call with max_length=128, showing only the input_ids tensor
tokenizer(check_list[0][1], return_tensors="pt", padding=True, truncation=True, max_length=128)['input_ids']

In [None]:
# input_ids for the same string with max_length=512
# (presumably to check whether a 128-token limit truncates it -- confirm)
tokenizer(check_list[0][1], return_tensors="pt", padding=True, truncation=True, max_length=512)['input_ids']

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from transformers import get_linear_schedule_with_warmup
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from tqdm import tqdm
import matplotlib.pyplot as plt

# Load the ChemBERTa tokenizer and a single-output sequence-classification head
checkpoint_name = "seyonec/ChemBERTa-zinc-base-v1"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_name, num_labels=1)

# Use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load the data
# NOTE(review): `dm` is not imported in this cell; it relies on an earlier
# cell's `import datamol as dm`.
print("Loading data...")
df = dm.data.freesolv()
X = df["smiles"]
y = df["expt"]

# Hold out 40% of the rows, then halve that hold-out into validation and test
print("Splitting data into train, validation, and test sets...")
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.4, random_state=42
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)

# Augment the training data: 100 randomized SMILES per training molecule,
# each paired with the molecule's original target.
# NOTE(review): `sme` is not imported in this cell; it has to be defined by an
# earlier cell for this to run.
print("Augmenting training data...")
listx = X_train.tolist()
listy = y_train.tolist()
new_X = [sme.randomize_smiles(smiles) for smiles in listx for _ in range(100)]
new_y = [target for target in listy for _ in range(100)]

# Back to pandas: build a one-column DataFrame and squeeze() it
X_train = pd.DataFrame(new_X).squeeze()
y_train = pd.DataFrame(new_y).squeeze()

# Tokenize each split in its own call, truncating at 128 tokens
print("Tokenizing datasets...")
tokenizer_options = dict(padding=True, truncation=True, max_length=128, return_tensors="pt")
train_encodings = tokenizer(list(X_train), **tokenizer_options)
val_encodings = tokenizer(list(X_val), **tokenizer_options)
test_encodings = tokenizer(list(X_test), **tokenizer_options)

# Targets of every split as float32 tensors
y_train_tensor, y_val_tensor, y_test_tensor = (
    torch.tensor(targets.values, dtype=torch.float32)
    for targets in (y_train, y_val, y_test)
)

# Each dataset yields (input_ids, attention_mask, label) tuples
print("Creating datasets...")
train_dataset = TensorDataset(
    train_encodings['input_ids'],
    train_encodings['attention_mask'],
    y_train_tensor,
)
val_dataset = TensorDataset(
    val_encodings['input_ids'],
    val_encodings['attention_mask'],
    y_val_tensor,
)
test_dataset = TensorDataset(
    test_encodings['input_ids'],
    test_encodings['attention_mask'],
    y_test_tensor,
)

# Batch the datasets; only the training loader shuffles
print("Creating DataLoader for batching...")
train_loader = DataLoader(dataset=train_dataset, batch_size=1024, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=1024)
test_loader = DataLoader(dataset=test_dataset, batch_size=1024)

# Optimizer, linear LR schedule and SWA objects
optimizer = Adam(params=model.parameters(), lr=1e-5)
# Schedule length: batches per epoch times 1000 (the training loop's maximum
# number of epochs)
total_steps = 1000 * len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)
swa_start = 5  # Epoch to start SWA
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=1e-5)

# Early-stopping / best-snapshot bookkeeping
patience = 10
epochs_no_improve = 0
best_val_loss = best_swa_eval_loss = float("inf")
best_model_state = best_swa_model_state = None

# Training loop
# Per epoch: train on train_loader, validate `model`, and -- once
# epoch >= swa_start -- update and validate the averaged `swa_model`.
# Early stopping only counts epochs after SWA has started.
print("Starting training loop...")
for epoch in range(1000):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}"):
        optimizer.zero_grad()
        input_ids, attention_mask, labels = (item.to(device) for item in batch)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        if epoch < swa_start:
            # The linear schedule is sized per batch, so it steps here -- but
            # only until SWALR takes over; stepping both would let this
            # scheduler overwrite the SWA learning rate on every batch.
            scheduler.step()
        total_loss += loss.item()

    print(f"Epoch {epoch + 1} - Training Loss: {total_loss / len(train_loader)}")

    model.eval()
    eval_loss = 0
    for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}"):
        with torch.no_grad():
            input_ids, attention_mask, labels = (item.to(device) for item in batch)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            eval_loss += loss.item()

    eval_loss /= len(val_loader)
    print(f"Epoch {epoch + 1} - Validation Loss: {eval_loss}")

    # Check for best model state based on validation loss. Done before the SWA
    # section so the epoch that triggers early stopping is still considered.
    if eval_loss < best_val_loss:
        best_val_loss = eval_loss
        # state_dict() returns references to the live parameters; clone them so
        # later optimizer steps cannot alter the saved snapshot.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

    swa_eval_loss = None
    if epoch >= swa_start:
        # SWA phase: fold the current weights into the running average and let
        # SWALR alone drive the optimizer's learning rate.
        swa_model.update_parameters(model)
        swa_scheduler.step()
        # NOTE(review): update_bn only recomputes BatchNorm statistics;
        # presumably a no-op for this model -- confirm.
        update_bn(train_loader, swa_model)
        swa_eval_loss = 0
        for batch in tqdm(val_loader, desc=f"SWA Validation Epoch {epoch + 1}"):
            with torch.no_grad():
                input_ids, attention_mask, labels = (item.to(device) for item in batch)
                outputs = swa_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                swa_eval_loss += loss.item()

        swa_eval_loss /= len(val_loader)
        print(f"SWA Model Evaluation after epoch {epoch + 1} - Validation Loss: {swa_eval_loss}")

        # Check for early stopping and best model state after SWA has started
        if swa_eval_loss < best_swa_eval_loss:
            best_swa_eval_loss = swa_eval_loss
            # Same aliasing concern as above: keep an independent copy.
            best_swa_model_state = {
                name: tensor.detach().clone()
                for name, tensor in swa_model.state_dict().items()
            }
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print("Early stopping due to no improvement in SWA validation loss.")
            break

# Load the best model state
# Evaluate the best SWA snapshot if it beat the best non-averaged validation
# loss, otherwise the best non-averaged snapshot; the choice is kept in `model`.
if best_swa_model_state is not None and best_swa_eval_loss < best_val_loss:
    print("Loading the best SWA model state...")
    swa_model.load_state_dict(best_swa_model_state)
    update_bn(train_loader, swa_model)
    model = swa_model
else:
    print("Loading the best model state...")
    # best_model_state is still None if no epoch ever improved the validation
    # loss; keep the current weights in that case.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

# Final evaluation of the selected model on the test set
print("Evaluating final model...")
model.eval()

# Weight every batch loss by its sample count so that a short final batch does
# not distort the average (assumes outputs.loss is the per-batch mean).
test_loss = 0.0
num_test_samples = 0
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        input_ids, attention_mask, labels = (item.to(device) for item in batch)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        samples_in_batch = labels.size(0)
        test_loss += outputs.loss.item() * samples_in_batch
        num_test_samples += samples_in_batch

test_loss /= num_test_samples
print(f"Final Test Loss: {test_loss}")

# Predictions on the test set
print("Generating predictions on the test set...")
model.eval()
predictions = []
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Predicting"):
        # Only the first two tensors of each batch are model inputs; the third
        # (the label) is not needed for prediction.
        input_ids, attention_mask = (item.to(device) for item in batch[:2])
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Flatten explicitly: a bare squeeze() turns a single-sample batch into
        # a 0-d tensor whose tolist() is a float, which extend() cannot take.
        preds = outputs.logits.reshape(-1).tolist()
        predictions.extend(preds)

# Scatter the test-set predictions against the true values, with the y = x
# diagonal drawn as the perfect-prediction reference.
print("Plotting predictions vs true values...")
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(y_test, predictions, alpha=0.6, color='green')
ax.set_xlabel('True Values')
ax.set_ylabel('Predictions')
ax.set_title('Test Set: True Values vs Predictions')
y_low, y_high = y_test.min(), y_test.max()
ax.plot([y_low, y_high], [y_low, y_high], 'k--', lw=2)
plt.show()
