In [None]:

import os
import sys
import torch.backends.cudnn as cudnn
import yaml
from utils import AttrDict
import pandas as pd
import numpy as np
from train import train

In [None]:
# cuDNN settings for this session: determinism is switched off and benchmark
# mode on (PyTorch docs: benchmark lets cuDNN auto-tune conv algorithms, so
# results are not guaranteed bit-for-bit reproducible across runs).
cudnn.deterministic = False
cudnn.benchmark = True

In [None]:
def _collect_characters(train_data, select_data):
    """Return every distinct character used in the labels of the selected datasets, sorted.

    Reads ``<train_data>/<name>/labels.csv`` for each ``name`` in the
    '-'-separated ``select_data`` string and gathers the characters of its
    ``words`` column.
    """
    charset = set()
    for data in select_data.split('-'):
        csv_path = os.path.join(train_data, data, 'labels.csv')
        # The anchored regex separator splits each line only at its first comma,
        # so label text may itself contain commas. dtype=str keeps labels as text:
        # without it an all-digit `words` column is parsed as integers, losing
        # leading zeros and making ''.join() below raise TypeError.
        df = pd.read_csv(
            csv_path,
            sep='^([^,]+),',
            engine='python',
            usecols=['filename', 'words'],
            dtype={'words': str},
            keep_default_na=False,
        )
        charset.update(''.join(df['words']))
    return ''.join(sorted(charset))


def get_config(file_path):
    """Load a YAML training config and fill in ``opt.character``.

    Args:
        file_path: Path to the YAML config file.

    Returns:
        The parsed config wrapped in ``AttrDict``. ``opt.character`` is set to
        the sorted set of characters found in the training labels when
        ``lang_char`` is the string ``'None'`` (or YAML null); otherwise it is
        ``number + symbol + lang_char``.

    Side effect:
        Creates ``./saved_models/<experiment_name>`` if it does not exist.
    """
    with open(file_path, 'r', encoding="utf8") as stream:
        opt = yaml.safe_load(stream)
    opt = AttrDict(opt)
    # `lang_char: None` in YAML loads as the string 'None', while an empty or
    # `null` value loads as Python None; both mean "derive the charset from data"
    # (None would otherwise crash the string concatenation in the else branch).
    if opt.lang_char is None or opt.lang_char == 'None':
        opt.character = _collect_characters(opt['train_data'], opt['select_data'])
    else:
        opt.character = opt.number + opt.symbol + opt.lang_char
    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)
    return opt

In [None]:
# Build the run configuration from YAML, then launch training
# (amp=True presumably enables automatic mixed precision — see train()).
config_file = "config_files/en_filtered_config.yaml"
opt = get_config(config_file)
train(opt, amp=True)

In [None]:
import torch
import torchvision
import numpy

# Report library versions (torch, torchvision, numpy) so the environment behind
# these results is recorded in the notebook output.
library_versions = (torch.__version__, torchvision.__version__, numpy.__version__)
print(*library_versions)


In [None]:
import re
import matplotlib.pyplot as plt

# Path to the training log to plot.
# NOTE(review): hard-coded and independent of the config loaded earlier (whose
# experiment_name decides the saved_models subfolder) — confirm this is the intended run.
log_file = "saved_models/ne_CTC/log_train.txt"

with open(log_file, "r", encoding="utf-8") as f:
    text = f.read()

# One loss value: a decimal (optionally with an exponent) or nan/inf, so that a
# diverged run is still captured instead of its log lines being silently skipped.
loss_value = r"([-+]?(?:(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|nan|inf))"
# Matches lines of the form "[<step>/<total>] Train loss: <x>, Valid loss: <y>".
pattern = r"\[(\d+)/\d+\]\s*Train loss:\s*" + loss_value + r",\s*Valid loss:\s*" + loss_value
matches = re.findall(pattern, text)
if not matches:
    # Fail loudly rather than render an empty figure.
    raise ValueError(f"No 'Train loss / Valid loss' lines found in {log_file}")

# Parse values (lists stay aligned element-wise, in file order)
steps = [int(m[0]) for m in matches]
train_loss = [float(m[1]) for m in matches]
valid_loss = [float(m[2]) for m in matches]

# Plot with the explicit Figure/Axes interface.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(steps, train_loss, label="Train Loss", linewidth=2)
ax.plot(steps, valid_loss, label="Validation Loss", linewidth=2)
ax.set_xlabel("Training Step")
ax.set_ylabel("Loss")
ax.set_title("Training vs Validation Loss (ResNet-BiLSTM-CTC)")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()
