In [None]:
import re
import matplotlib.pyplot as plt

# Pattern for one numeric log value, captured as a single group.  It accepts
# plain decimals ("0.5", "-.25"), exponent notation ("1.2e-03") and the
# non-finite values "nan"/"inf"/"infinity" (any case) -- all forms that
# ``float()`` accepts -- so lines logged while a loss or gradient explodes or
# vanishes are not silently skipped.  The alternatives inside are
# non-capturing, so the group numbering documented below is unaffected.
_FLOAT_GROUP = (
    r"("
    r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"  # decimal / exponent form
    r"|[-+]?(?i:nan|inf(?:inity)?)"               # non-finite values
    r")"
)

# Regex to capture training metrics:
# Group 1: epoch
# Group 2: total_epochs
# Group 3: step
# Group 4: AE_loss
# Group 5: D_loss
# Group 6: G_loss
# Group 7: Enc Grad
# Group 8: Dec Grad
# Group 9: Disc Grad
# Group 10: Gen Grad
# The Gen Grad value must be followed by whitespace or the end of the string.
# This still matches lines with trailing fields, and it also matches lines
# that end right after the value (which happens once the caller strips them).
training_line_regex = re.compile(
    r"\[Epoch (\d+)\/(\d+)\] \[Step (\d+)\] "
    r"AE_loss\(EMD \* [^\)]+\): " + _FLOAT_GROUP + r" \| "
    r"D_loss: " + _FLOAT_GROUP + r" \| "
    r"G_loss: " + _FLOAT_GROUP + r" \| "
    r"Enc Grad: " + _FLOAT_GROUP + r" \| "
    r"Dec Grad: " + _FLOAT_GROUP + r" \| "
    r"Disc Grad: " + _FLOAT_GROUP + r" \| "
    r"Gen Grad: " + _FLOAT_GROUP + r"(?=\s|$)"
)

# Regex to capture validation lines:
# Group 1: epoch
# Group 2: val_loss
val_line_regex = re.compile(
    r"Validation AE_loss at epoch (\d+): " + _FLOAT_GROUP
)

def parse_training_log(log_file_path="training_log.txt"):
    """Extract training and validation metrics from a text log file.

    Each line is stripped and then tested against both
    ``training_line_regex`` and ``val_line_regex``.  Values from every
    matching line are collected in file order.

    Parameters
    ----------
    log_file_path : str or path-like
        Log file to read; it is opened as UTF-8 text.

    Returns
    -------
    dict[str, list]
        ``"epoch"`` and ``"step"`` hold ints.  ``"AE_loss"``, ``"D_loss"``,
        ``"G_loss"``, ``"Enc_Grad"``, ``"Dec_Grad"``, ``"Disc_Grad"`` and
        ``"Gen_Grad"`` hold floats.  Each of these lists gets one entry per
        matched training line.  ``"val_epoch"`` (ints) and ``"val_AE_loss"``
        (floats) get one entry per matched validation line.
    """
    # (result key, training-regex group index, converter).  Group 2 (the
    # total number of epochs) is deliberately not recorded.
    train_fields = (
        ("epoch", 1, int),
        ("step", 3, int),
        ("AE_loss", 4, float),
        ("D_loss", 5, float),
        ("G_loss", 6, float),
        ("Enc_Grad", 7, float),
        ("Dec_Grad", 8, float),
        ("Disc_Grad", 9, float),
        ("Gen_Grad", 10, float),
    )
    metrics = {key: [] for key, _, _ in train_fields}
    metrics["val_epoch"] = []
    metrics["val_AE_loss"] = []

    with open(log_file_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()

            hit = training_line_regex.search(text)
            if hit is not None:
                # Convert every field first so a bad value raises before
                # any list is touched.
                parsed = [
                    (key, convert(hit.group(index)))
                    for key, index, convert in train_fields
                ]
                for key, value in parsed:
                    metrics[key].append(value)

            # A line is checked for validation data regardless of whether it
            # was a training line.
            hit = val_line_regex.search(text)
            if hit is not None:
                epoch_value = int(hit.group(1))
                loss_value = float(hit.group(2))
                metrics["val_epoch"].append(epoch_value)
                metrics["val_AE_loss"].append(loss_value)

    return metrics


def plot_metrics(metrics_dict):
    """Plot training losses and gradient norms in a 2x2 grid of subplots.

    Panels:
    1. reconstruction loss (``AE_loss``);
    2. discriminator vs generator loss (``D_loss``, ``G_loss``);
    3. encoder/decoder gradient norms (``Enc_Grad``, ``Dec_Grad``);
    4. discriminator/generator gradient norms (``Disc_Grad``, ``Gen_Grad``).

    Parameters
    ----------
    metrics_dict : dict
        Mapping shaped like the one returned by ``parse_training_log``.  This
        function reads the keys ``"step"``, ``"AE_loss"``, ``"D_loss"``,
        ``"G_loss"``, ``"Enc_Grad"``, ``"Dec_Grad"``, ``"Disc_Grad"`` and
        ``"Gen_Grad"``.

    If no training lines were parsed, a message is printed and nothing is
    plotted, in the same way ``plot_validation`` handles missing data.
    """
    steps = metrics_dict["step"]
    if len(steps) == 0:
        # Usually means the log format did not match the training regex;
        # say so instead of drawing four empty panels.
        print("No training data found in log.")
        return

    # Step counters can restart, for example when the step is an index within
    # each epoch or when a resumed run appends to the same log.  Plotting
    # against such values makes the curves jump back to the left, so in that
    # case fall back to the order in which the lines appear in the log.
    if all(prev <= cur for prev, cur in zip(steps, steps[1:])):
        x_values, x_label = steps, "Step"
    else:
        x_values, x_label = list(range(len(steps))), "Logged iteration"

    # (title, y-axis label, [(metrics key / line label, color or None)]).
    # A color of None leaves matplotlib's default color cycle in effect.
    panels = [
        ("Reconstruction Loss (EMD) over Steps", "AE_loss",
         [("AE_loss", None)]),
        ("Discriminator vs Generator Loss", "Loss",
         [("D_loss", "red"), ("G_loss", "green")]),
        ("Encoder/Decoder Gradient Norms", "Gradient Norm",
         [("Enc_Grad", None), ("Dec_Grad", None)]),
        ("Discriminator/Generator Gradient Norms", "Gradient Norm",
         [("Disc_Grad", "orange"), ("Gen_Grad", "purple")]),
    ]

    plt.figure(figsize=(12, 8))
    for position, (title, y_label, series) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        for key, color in series:
            # Pass ``color`` only when it is set, so uncolored lines keep
            # their default colors.
            style = {"color": color} if color is not None else {}
            plt.plot(x_values, metrics_dict[key], label=key, **style)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()

    plt.tight_layout()
    plt.show()

def plot_validation(metrics_dict):
    """Plot validation reconstruction loss against epoch.

    Parameters
    ----------
    metrics_dict : dict
        Mapping shaped like the one returned by ``parse_training_log``.  This
        function reads ``"val_epoch"`` (x axis) and ``"val_AE_loss"``
        (y axis).

    If there are no validation points, a message is printed and no figure
    is created.
    """
    epochs = metrics_dict["val_epoch"]
    losses = metrics_dict["val_AE_loss"]

    if len(epochs) == 0:
        print("No validation data found in log.")
        return

    plt.figure(figsize=(6, 4))
    plt.plot(epochs, losses, marker='o', label="Val_AE_loss")
    plt.xlabel("Epoch")
    plt.ylabel("Validation AE_loss")
    plt.title("Validation Reconstruction Loss Over Epochs")
    plt.legend()
    plt.show()

In [None]:

# Location of the training log; edit this to point at your run.  Forward
# slashes work on both Windows and POSIX.  On Windows this still resolves to
# the root of the current drive, just like "\Generative_3DWheatNet\...", but
# it is no longer read as a single backslash-laden filename elsewhere.
LOG_PATH = "/Generative_3DWheatNet/training_log.txt"

# 1. Parse the log
metrics = parse_training_log(LOG_PATH)

# 2. Plot the training metrics
plot_metrics(metrics)

# 3. Plot the validation metrics
plot_validation(metrics)