# In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW
import torch

# Load the tokenizer and model

# Alternatives tried: "openai-community/gpt2-large", "AI4Chem/ChemLLM-20B-Chat-SFT"
model_name_or_id = "AI4Chem/CHEMLLM-2b-1_5"

# Check if CUDA is available and set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# float16 halves memory on the GPU, but on CPU float16 is slow and some ops
# have no Half kernel in older PyTorch builds, so fall back to float32 there.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Load model and tokenizer from local directories model/<id> and tokenizer/<id>.
# NOTE: trust_remote_code=True executes Python code shipped with the checkpoint;
# only use it with checkpoints you trust.
model = AutoModelForCausalLM.from_pretrained(
    f"model/{model_name_or_id}",
    torch_dtype=torch_dtype,
    trust_remote_code=True,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(
    f"tokenizer/{model_name_or_id}",
    trust_remote_code=True,
)
# Reuse EOS as the padding token so batches of different lengths can be padded.
tokenizer.pad_token = tokenizer.eos_token


# Function to concatenate user and assistant messages into a single text string
def concatenate_messages(dataset):
    """Flatten one chat record into a plain-text transcript.

    Args:
        dataset: A single conversation record (despite the name) holding a
            ``"messages"`` list of dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        The transcript, with one ``"User: <content>\\n"`` or
        ``"Assistant: <content>\\n"`` line per message, in order. Messages
        whose role is neither ``"user"`` nor ``"assistant"`` (e.g. a system
        prompt) are skipped.
    """
    role_prefixes = {"user": "User: ", "assistant": "Assistant: "}
    # Collect the pieces and join once rather than growing a string with +=.
    return "".join(
        role_prefixes[message["role"]] + message["content"] + "\n"
        for message in dataset["messages"]
        if message["role"] in role_prefixes
    )

# Prepare the texts by concatenating messages
texts = [concatenate_messages(dataset) for dataset in datasets]

# Tokenize the concatenated texts into a single padded batch
tokenized_inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

# Move the inputs to the model's device (the model itself was moved when loaded)
input_ids = tokenized_inputs.input_ids.to(device)
attention_mask = tokenized_inputs.attention_mask.to(device)

# Padding must not contribute to the loss. pad_token is the EOS token, so
# using input_ids directly as labels would teach the model to emit runs of EOS.
# -100 is the label index the Hugging Face causal-LM loss ignores.
labels = input_ids.masked_fill(attention_mask == 0, -100)

# Define the optimizer. transformers.AdamW is deprecated in favour of
# torch.optim.AdamW. eps and weight_decay are set explicitly to the old
# transformers defaults (1e-6, 0.0); torch's defaults (1e-8, 0.01) would
# change training. Also, 1e-8 rounds to 0 in float16, which gives 0/0 = NaN
# updates for parameters whose gradient is zero.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

# Fine-tuning loop: the whole dataset is one batch, so this is one optimizer
# step per epoch.
# NOTE(review): if the model was loaded in float16, these are raw fp16 weights
# with no loss scaling or fp32 master copy; small updates may underflow.
# Consider bfloat16/float32 or torch.autocast + GradScaler.
num_epochs = 3  # Number of epochs for training
model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs.loss
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

# Save the fine-tuned model
model.save_pretrained("fine_tuned_model")
tokenizer.save_pretrained("fine_tuned_model")

# Generate a response using the fine-tuned model
model.eval()

# Prepare an example input in the same "User: ...\nAssistant:" format used
# for the training transcripts
example_input = "User: Can you suggest a polymer with Tg=300 and Er=200?\nAssistant:"

# Tokenize the example input and keep the attention mask. Since pad_token is
# the EOS token, generate() cannot reliably infer the mask on its own.
example_encoding = tokenizer(example_input, return_tensors="pt").to(device)
example_input_ids = example_encoding.input_ids

# Generate a response
generated_ids = model.generate(
    input_ids=example_input_ids,
    attention_mask=example_encoding.attention_mask,
    max_new_tokens=50,
    pad_token_id=tokenizer.eos_token_id,
)

# Decode and print the response (prompt included, special tokens removed)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print("Generated Response:")
print(generated_text)
