In [None]:
import json

### Select which fine-tuned model's loss history to load
# path = "./llama-ft"
path = "./mistral-ft"

# Read trainer_state.json from the checkpoint-375 directory of the selected run.
with open(f"{path}/checkpoint-375/trainer_state.json") as f:
    item = json.load(f)

# Only the "log_history" list of the trainer state is used below.
history = item["log_history"]

# Entries that carry a "loss" key become training points.
loss = [
    {"loss": record["loss"], "epoch": record["epoch"], "step": record["step"]}
    for record in history
    if "loss" in record
]

# Entries that carry an "eval_loss" key become validation points.
val = [
    {"eval_loss": record["eval_loss"], "epoch": record["epoch"], "step": record["step"]}
    for record in history
    if "eval_loss" in record
]

In [None]:
import matplotlib.pyplot as plt

# Training loss series: one point per entry collected in `loss`
losses = [entry['loss'] for entry in loss]
epochs = [entry['epoch'] for entry in loss]
steps = [entry['step'] for entry in loss]

# Validation loss series: one point per entry collected in `val`
val_losses = [entry['eval_loss'] for entry in val]
val_epochs = [entry['epoch'] for entry in val]
val_steps = [entry['step'] for entry in val]

# Derive the model name shown in the title from `path`, so the title follows
# whichever run is selected instead of always saying "Mistral".
if "llama" in path.lower():
    model_name = "Llama3"
elif "mistral" in path.lower():
    model_name = "Mistral"
else:
    # Unrecognised run directory: fall back to showing the path itself.
    model_name = path

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(steps, losses, marker=',', label='Training Loss')
ax.plot(val_steps, val_losses, 'ro', label='Validation Loss')

# The two series have distinct labels, so a plain legend needs no de-duplication.
ax.legend()

ax.set_title(f'{model_name} Fine-Tuning Loss')
ax.set_xlabel('Steps')
# The plotted values are the raw logged loss numbers, so no percent unit is shown.
ax.set_ylabel('Loss')
ax.grid(which='both', linewidth=0.5)
plt.show()

In [None]:
# Close any figures left open by earlier cells. Unlike plt.clf(), this does not
# implicitly create a new blank figure when none is open, so plt.show() below
# renders only the bar chart.
plt.close('all')

import numpy as np

# Manually entered inference runtimes (minutes), one value per entry in `labels`
labels = ["Base", "Fine-Tuned", "RAG1", "RAG2", "FT-RAG1", "FT-RAG2"]
runtime_llama3 = [23.16, 27.80, 55.56, 51.40, 61.61, 56.16]
runtime_mistral = [96.63, 29.46, 122.64, 142.23, 71.20, 67.84]

# x position of each label group
x = np.arange(len(labels))

# Bar width; the two bars of a group sit half a width to either side of x
width = 0.35

fig, ax = plt.subplots(figsize=(8, 5))
llama = ax.bar(x - width/2, runtime_llama3, width, label='Llama3', color='tomato')
mistral = ax.bar(x + width/2, runtime_mistral, width, label='Mistral', color='royalblue')

ax.set_xlabel('Configuration')
ax.set_ylabel('Runtime in Minutes')
ax.set_title('Inference Runtime')
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()
# Horizontal grid lines only, drawn on this axes explicitly rather than via pyplot state.
ax.grid(axis="y", linewidth=0.5)

plt.show()

In [None]:
# Close any figures left open by earlier cells. Unlike plt.clf(), this does not
# implicitly create a new blank figure when none is open, so plt.show() below
# renders only the bar chart.
plt.close('all')

import numpy as np

# Manually entered inference memory usage (GB), one value per entry in `labels`
labels = ["Base", "Fine-Tuned", "RAG1", "RAG2", "FT-RAG1", "FT-RAG2"]
mem_llama3 = [6.12, 6.23, 7.81, 8.43, 9.97, 8.78]
mem_mistral = [5.21, 5.24, 6.54, 6.07, 6.75, 6.27]

# x position of each label group
x = np.arange(len(labels))

# Bar width; the two bars of a group sit half a width to either side of x
width = 0.35

fig, ax = plt.subplots(figsize=(8, 5))
llama = ax.bar(x - width/2, mem_llama3, width, label='Llama3', color='tomato')
mistral = ax.bar(x + width/2, mem_mistral, width, label='Mistral', color='royalblue')

ax.set_xlabel('Configuration')
ax.set_ylabel('Memory Usage (GB)')
ax.set_title('Inference Memory Usage')
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()
# Horizontal grid lines only, drawn on this axes explicitly rather than via pyplot state.
ax.grid(axis="y", linewidth=0.5)

plt.show()