In [None]:
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import torch

# Modify this for your machine
DEVICE = "cuda" if torch.cuda.is_available() else "mps"
MAIN_PATH = (
    "/content/drive/MyDrive/ecg-timeseries-model/models/" if DEVICE == "cuda" else ""
)

train_df = pd.read_pickle(MAIN_PATH + "train_df.pkl")
val_df = pd.read_pickle(MAIN_PATH + "val_df.pkl")

INPUT_LENGTH = 9000
BATCH_SIZE = 64
EPOCHS = 10
NUM_WORKERS = 8 if DEVICE == "cuda" else 0
NUM_CLASSES = 4

CLASS_NAMES = {0: "Normal", 1: "AF", 2: "Other", 3: "Noisy"}


TRAINING_MEAN = 1.07e-09
TRAINING_STD = 175.11


def set_seed(seed=42):
    random.seed(seed)  # Python random
    np.random.seed(seed)  # NumPy
    torch.manual_seed(seed)  # PyTorch CPU
    torch.cuda.manual_seed(seed)  # PyTorch GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)
print("Seed set for reproducibility.")

Seed set for reproducibility.


In [281]:
# Augmentation functions
import scipy.signal
import torch
import numpy as np
import random
from torch.nn.utils.rnn import pack_sequence

# wrapping them in nn.Module should not introduce huge overhead


def time_shift(signal, shift_range=(-100, 100)):
    shift = np.roll(signal, shift=random.randint(shift_range[0], shift_range[1]))
    return shift


def add_noise(signal, noise_level=0.1):
    noise = np.random.normal(0, noise_level, signal.shape)
    return signal + noise


def time_warp(signal, warp_factor=0.1):
    return scipy.signal.resample(
        signal, int(len(signal) * (1 + np.random.uniform(-warp_factor, warp_factor)))
    )


def amplitude_scaling(signal, scale_range=(0.8, 1.2)):
    scale = random.uniform(scale_range[0], scale_range[1])
    return signal * scale


def augment(signal, augmentation="all"):
    if augmentation in ("all", "warp_only") and np.random.rand() < 0.5:
        signal = time_warp(signal)
    if augmentation in ("all", "noise_only") and np.random.rand() < 0.5:
        signal = add_noise(signal)
    if augmentation in ("all", "shift_only") and np.random.rand() < 0.5:
        signal = time_shift(signal)
    if augmentation in ("all", "scale_only") and np.random.rand() < 0.5:
        signal = amplitude_scaling(signal)
    return signal


class ECGDataset(torch.utils.data.Dataset):
    def __init__(self, df, augmentation=None):
        self.df = df
        self.augmentation = augmentation

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        signal = self.df.iloc[idx]["signal"]
        signal = np.array(signal, dtype=np.float32)
        signal = augment(signal, self.augmentation) 
        signal = torch.tensor(signal, dtype=torch.float32)
        
        label = int(self.df.iloc[idx]["label"])
        label = torch.tensor(label, dtype=torch.long)

        length = len(signal)

        return signal, label, length


def collate_fn(batch):
    # TODO: Optimization: sort by length
    signals, labels, lengths = zip(*batch)
    signals = [s.unsqueeze(-1) if s.dim() == 1 else s for s in signals]
    packed_signals = pack_sequence(signals, enforce_sorted=False)

    lengths = torch.tensor(lengths, dtype=torch.int64)
    labels = torch.stack(labels)

    return packed_signals, labels, lengths

In [282]:
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.jit.script
def batch_apply_stft(
    signals: torch.Tensor, 
    lengths: torch.Tensor, 
    n_fft: int = N_FFT,  
    hop_length: int = HOP_LENGTH,  
    pad_mode: str = "constant"
) -> tuple[torch.Tensor, torch.Tensor]:
    
    batch_stft = torch.stft(
        signals.squeeze(-1),
        n_fft,
        hop_length,
        window=torch.hann_window(n_fft, device=signals.device),
        pad_mode=pad_mode,
        return_complex=True,
    )
    magnitude = torch.abs(batch_stft)

    new_lengths = torch.clamp((lengths - n_fft) // hop_length + 1, min=1)

    return magnitude, new_lengths

def he_init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [283]:
# add data parallelism support
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence


class VorgabeRNN(nn.Module):
    def __init__(
        self,
        hidden_size=50,
        num_layers=2,
        num_classes=NUM_CLASSES,
        n_fft=512,
        hop_length=256,
        dropout_rate=0.2,
    ):
        super(VorgabeRNN, self).__init__()

        # STFT parameters
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_freqs = n_fft // 2 + 1 

        self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(dropout_rate)

        self.conv_output_size = 64 * (self.n_freqs // 4)  

        self.rnn = nn.RNN(self.conv_output_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.stft = batch_apply_stft
        self.apply(he_init_weights)

        self.rnn_type = type(self.rnn).__name__

    def feature_extractor(self, x: PackedSequence, lengths: torch.Tensor) -> torch.Tensor:
        lengths = lengths.detach().clone().cpu()  

        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        x, lengths = self.stft(x, lengths, n_fft=self.n_fft, hop_length=self.hop_length)

        x = torch.log2(x + 1e-8)
        x = x.unsqueeze(1)

        x = self.pool(F.relu(self.conv1(x)))
        x = self.dropout(x)
        x = self.pool(F.relu(self.conv2(x)))
        x = self.dropout(x)

        batch_size, channels, freq_bins, time_frames = x.shape
        x = x.view(batch_size, time_frames, -1)  

        lengths = lengths // 4
        lengths = torch.clamp(lengths, min=1)

        x = torch.nn.utils.rnn.pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False
        )

        if self.rnn_type == "LSTM":
            LSTM_output, LSTM_states = self.rnn(x)
            return LSTM_states[0][-1]
        else:
            RNN_output, RNN_states = self.rnn(x)
            return RNN_states[-1]

    def forward(self, x: PackedSequence, lengths: torch.Tensor) -> torch.Tensor:
        x = self.feature_extractor(x, lengths)
        x = self.fc(x)
        return x

    def predict(self, x: PackedSequence, lengths: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            x = self.forward(x, lengths)
            _, predicted = torch.max(x, 1)
        return predicted

In [284]:
class VorgabeLSTM(VorgabeRNN):
    def __init__(self, hidden_size=50, num_layers=2, num_classes=NUM_CLASSES, n_fft=512, hop_length=256, dropout_rate=0.2):
        super().__init__(hidden_size, num_layers, num_classes, n_fft, hop_length, dropout_rate)
        self.rnn = nn.LSTM(self.conv_output_size, hidden_size, num_layers, batch_first=True)

class VorgabeGRU(VorgabeRNN):
    def __init__(self, hidden_size=50, num_layers=2, num_classes=NUM_CLASSES, n_fft=512, hop_length=256, dropout_rate=0.2):
        super().__init__(hidden_size, num_layers, num_classes, n_fft, hop_length, dropout_rate)
        self.rnn = nn.GRU(self.conv_output_size, hidden_size, num_layers, batch_first=True)


In [285]:
import sklearn.svm
class RNNwithSVM(VorgabeRNN):
    def __init__(
        self,
        hidden_size=50,
        num_layers=2,
        num_classes=NUM_CLASSES,
        n_fft=512,
        hop_length=256,
        dropout_rate=0.2,
    ):
        super(RNNwithSVM, self).__init__(
            hidden_size, num_layers, num_classes, n_fft, hop_length, dropout_rate
        )
        self.fc1 = nn.Linear(hidden_size, hidden_size) 
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.svm = sklearn.svm.SVC()
        self.is_svm_trained = False

    def forward(self, x: PackedSequence, lengths: torch.Tensor) -> torch.Tensor:
        x = self.feature_extractor(x, lengths)[-1]
        x = self.fc1(x)
        if self.fc2 is not None and not self.is_svm_trained:
            return self.fc2(x)
        return x

    def predict(self, x: PackedSequence, lengths: torch.Tensor) -> torch.Tensor:
        if not self.is_svm_trained:
            raise ValueError("SVM not trained yet!")
        self.eval()
        with torch.no_grad():
            x = self.feature_extractor(x, lengths)[-1]
            x = self.fc1(x)
            x = self.svm.predict(x.cpu().numpy())
        return x
    def _remove_temp_classifier(self):
        self.fc2 = None



In [294]:
from torch.utils.data import DataLoader

train_dataset = ECGDataset(train_df, augmentation=None)
val_dataset = ECGDataset(val_df)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    generator=torch.Generator().manual_seed(42),
    collate_fn=collate_fn,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
)

# check the actual signal lengths from the paddedsequences after   x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
for signals, labels, lengths in val_loader:
    print(f"Batch size: {signals.batch_sizes}")
    print(f"Signal lengths: {lengths}")
    # signals is a PackedSequence
    print(f"Signal shapes: {[s.shape for s in signals]}")  # Print shapes of each signal in the batch
    x, _ = torch.nn.utils.rnn.pad_packed_sequence(signals, batch_first=True)
    print(f"actual signal lengths: {x.shape[1]}")
    break

Batch size: tensor([64, 64, 64,  ..., 10, 10, 10])
Signal lengths: tensor([18000,  9000,  9000,  9000,  9000,  4500,  9000, 14648, 11520,  9000,
         8226,  5834,  9000,  9000,  9000,  9000,  9000,  9000,  9000, 18000,
         9000,  9000, 18000,  9000,  9000, 10582,  9000,  9000,  9000,  9000,
         9000,  9000,  9000, 18000,  9000,  9000,  9000, 18000, 18000,  9000,
        18000, 18000,  9000,  9000,  9000,  9000,  9000,  9000,  9000,  9000,
         9000,  6162,  9000,  9000,  9000,  9000,  9000,  9000,  9000, 18000,
         9000,  9000, 18000,  9000])
Signal shapes: [torch.Size([664472, 1]), torch.Size([18000]), torch.Size([64]), torch.Size([64])]
actual signal lengths: 18000


In [287]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from sklearn.utils.class_weight import compute_class_weight



def train_model(model, train_loader, epochs=10, lr=0.001, batch_size_factor=4):
    device = "cuda" if torch.cuda.is_available() else "mps"
    model.to(device)

    is_nn_svm = hasattr(model, "svm")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    original_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.array([0, 1, 2, 3]),
    y=train_df['label'].values
)

    # Normalize and scale the weights
    scaled_weights = original_weights / original_weights.max()  # normalize to max=1
    scaled_weights = 0.5 + (scaled_weights * 0.5)  # shrink range to [0.5, 1.0] for balance

    weights_tensor = torch.tensor(scaled_weights, dtype=torch.float32).to(device)

    criterion = nn.CrossEntropyLoss(weight=weights_tensor)

    print("Training RNN feature extractor...")
    model.train()
    for epoch in range(epochs):
        total_loss, correct, total = 0, 0, 0

        for signals, labels, lengths in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            signals, labels, lengths = signals.to(device), labels.to(device), lengths.to(device)

            optimizer.zero_grad()
            outputs = model(signals, lengths)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        acc = 100.0 * correct / total
        print(
            f"Epoch {epoch+1}: Train Loss: {total_loss/len(train_loader):.4f}, Train Acc: {acc:.2f}%"
        )

    if is_nn_svm:
        model._remove_temp_classifier()
        _train_svm_phase(model, train_loader, device, batch_size_factor) 

    return model



def _train_svm_phase(model, train_loader, device, batch_size_factor):

    print("Training SVM...")
    model.eval()
    all_features, all_labels = [], []
    batch_features, batch_labels = [], []
    assert model.fc2 is None, "Temporary classifier should be removed before SVM training"

    # extract_features_from_loader(model, data_loader, device):
    with torch.no_grad():
        for signals, labels, lengths in tqdm(train_loader, desc="Extracting features"):
            signals, lengths = signals.to(device), lengths.to(device)
            features = model(signals, lengths)

            batch_features.append(features)
            batch_labels.append(labels)

            if len(batch_features) >= batch_size_factor:
                combined_features = torch.cat(batch_features, dim=0).cpu().numpy()
                combined_labels = torch.cat(batch_labels, dim=0).numpy()
                all_features.append(combined_features)
                all_labels.append(combined_labels)
                batch_features, batch_labels = [], []


    if batch_features:
        combined_features = torch.cat(batch_features, dim=0).cpu().numpy()
        combined_labels = torch.cat(batch_labels, dim=0).numpy()
        all_features.append(combined_features)
        all_labels.append(combined_labels)

    final_features = np.concatenate(all_features, axis=0)
    final_labels = np.concatenate(all_labels, axis=0)

    print(f"Training SVM on {len(final_features)} samples")
    model.svm.fit(final_features, final_labels)
    model.is_svm_trained = True

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix


def evaluate_model(model, val_loader, device):
    print("Evaluating model...")
    model.eval()
    all_predictions, all_labels = [], []
    is_nn_svm = hasattr(model, "svm")

    with torch.no_grad():
        for signals, labels, lengths in tqdm(val_loader, desc="Evaluating"):
            signals, lengths = (
                signals.to(device),
                lengths.to(device),
            )

            predictions = model.predict(signals, lengths)
            if not is_nn_svm:
                predictions = predictions.cpu().numpy()

            all_predictions.extend(predictions)
            all_labels.extend(labels.numpy())

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f"Final Validation Accuracy: {accuracy:.4f}")
    print(classification_report(all_labels, all_predictions))

    # confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)
    # normalize heatmap colors in each row for class distribution
    cm_normalized = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
    sns.heatmap(
        cm_normalized,
        annot=True,
        fmt=".2f",
        cmap="Blues",
        xticklabels=CLASS_NAMES.values(),
        yticklabels=CLASS_NAMES.values(),
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

    return accuracy

In [None]:
import torch

input_shape = (BATCH_SIZE, INPUT_LENGTH, 1)

for model in [VorgabeRNN, VorgabeLSTM, VorgabeGRU]:

    model_instance = model(
        hidden_size=128,
        num_layers=2,
        num_classes=NUM_CLASSES,
        n_fft=256,
        hop_length=128,
        dropout_rate=0.2,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_instance = model_instance.to(device)
    model_instance.eval()
    device = "mps"
    dummy_input = torch.randn(*input_shape, device=device)
    dummy_input = pack_sequence(
        [dummy_input] * input_shape[0], enforce_sorted=False
    )  
    dummy_lengths = torch.tensor(
        [input_shape[1]] * input_shape[0], dtype=torch.int64, device=device
    )
    scripted = torch.jit.script(model_instance)
    print(scripted.graph)
    scripted.save(f"{model.__name__}.pt")

graph(%self : __torch__.___torch_mangle_71.VorgabeRNN,
      %x.1 : NamedTuple(data : Tensor, batch_sizes : Tensor, sorted_indices : Tensor?, unsorted_indices : Tensor?),
      %lengths.1 : Tensor):
  %x.5 : Tensor = prim::CallMethod[name="feature_extractor"](%self, %x.1, %lengths.1) # /var/folders/8y/9y3t89w944n3vzl3tq_mzvs80000gn/T/ipykernel_66268/300297223.py:71:12
  %fc : __torch__.torch.nn.modules.linear.___torch_mangle_37.Linear = prim::GetAttr[name="fc"](%self)
  %x.9 : Tensor = prim::CallMethod[name="forward"](%fc, %x.5) # /var/folders/8y/9y3t89w944n3vzl3tq_mzvs80000gn/T/ipykernel_66268/300297223.py:72:12
  return (%x.9)

graph(%self : __torch__.___torch_mangle_72.VorgabeLSTM,
      %x.1 : NamedTuple(data : Tensor, batch_sizes : Tensor, sorted_indices : Tensor?, unsorted_indices : Tensor?),
      %lengths.1 : Tensor):
  %x.5 : Tensor = prim::CallMethod[name="feature_extractor"](%self, %x.1, %lengths.1) # /var/folders/8y/9y3t89w944n3vzl3tq_mzvs80000gn/T/ipykernel_66268/300297223

In [291]:
def check_gradients(model):
    model.train()  
    device = next(model.parameters()).device

    dummy_sequences = []
    for i in range(BATCH_SIZE):  
        seq = torch.randn(
            1000, 1, device=device, requires_grad=True
        )
        dummy_sequences.append(seq)

    dummy_input = pack_sequence(dummy_sequences, enforce_sorted=False)
    dummy_lengths = torch.tensor(
        [1000] * BATCH_SIZE, dtype=torch.int64, device=device
    )

    output = model(dummy_input, dummy_lengths)
    loss = output.sum()  
    loss.backward()

    for name, param in model.named_parameters():
        if param.grad is not None:
            if param.grad.abs().sum() == 0:
                print(f"Gradients for {name} are zero.")


for model in [VorgabeRNN, VorgabeLSTM, VorgabeGRU]:
    model_instance = model(
        hidden_size=128,
        num_layers=2,
        num_classes=NUM_CLASSES,
        n_fft=256,
        hop_length=128,
        dropout_rate=0.2,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_instance = model_instance.to(device)
    check_gradients(model_instance)

Gradients for rnn.weight_hh_l0 are zero.
Gradients for rnn.weight_hh_l1 are zero.
Gradients for rnn.weight_hh_l0 are zero.
Gradients for rnn.weight_hh_l1 are zero.
Gradients for rnn.weight_hh_l0 are zero.
Gradients for rnn.weight_hh_l1 are zero.


In [None]:
test_space = {
    "model": [VorgabeLSTM, VorgabeGRU, VorgabeRNN],
    "augmentation": ["all"],
}

results = []
print("Starting tests with different configurations...")

for aug in test_space["augmentation"]:
    train_set = ECGDataset(train_df, augmentation=aug)
    val_set = ECGDataset(val_df)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, generator=torch.Generator().manual_seed(42), num_workers=NUM_WORKERS, collate_fn=collate_fn)
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=collate_fn)

    for model_cls in test_space["model"]:
        print(f"Testing {model_cls.__name__} with {aug}")
        model = model_cls(hidden_size=128, 
                          num_layers=2, 
                          num_classes=NUM_CLASSES, 
                          n_fft=N_FFT, 
                          hop_length=HOP_LENGTH, 
                          dropout_rate=0.2)
        device = "cuda" if torch.cuda.is_available() else "mps"
        model = model.to(device)
        trained = train_model(model, train_loader, epochs=EPOCHS, lr=0.001, batch_size_factor=4)
        acc = evaluate_model(trained, val_loader, device)

        print(f"Validation Accuracy of {model_cls.__name__} with {aug}: {acc:.4f}\n")
        results.append({"model": model_cls.__name__, "augmentation": aug, "accuracy": acc})
        del trained
        torch.cuda.empty_cache()

    del train_set, val_set, train_loader, val_loader

print("All tests completed.")
best = max(results, key=lambda x: x["accuracy"])
print(f"Best Model: {best['model']}, Augmentation: {best['augmentation']}, Accuracy: {best['accuracy']:.4f}")


Starting tests with different configurations...
Testing VorgabeLSTM with all
Training RNN feature extractor...


Epoch 1/20: 100%|██████████| 83/83 [01:18<00:00,  1.06it/s]


Epoch 1: Train Loss: 1.1139, Train Acc: 58.23%


Epoch 2/20: 100%|██████████| 83/83 [01:00<00:00,  1.37it/s]


Epoch 2: Train Loss: 1.0830, Train Acc: 58.76%


Epoch 3/20: 100%|██████████| 83/83 [00:57<00:00,  1.44it/s]


Epoch 3: Train Loss: 1.0788, Train Acc: 58.64%


Epoch 4/20: 100%|██████████| 83/83 [00:57<00:00,  1.43it/s]


Epoch 4: Train Loss: 1.0739, Train Acc: 58.24%


Epoch 5/20: 100%|██████████| 83/83 [00:43<00:00,  1.89it/s]


Epoch 5: Train Loss: 1.0737, Train Acc: 58.34%


Epoch 6/20: 100%|██████████| 83/83 [00:41<00:00,  2.00it/s]


Epoch 6: Train Loss: 1.0685, Train Acc: 58.59%


Epoch 7/20: 100%|██████████| 83/83 [00:44<00:00,  1.88it/s]


Epoch 7: Train Loss: 1.0714, Train Acc: 58.32%


Epoch 8/20: 100%|██████████| 83/83 [00:44<00:00,  1.86it/s]


Epoch 8: Train Loss: 1.0661, Train Acc: 58.38%


Epoch 9/20: 100%|██████████| 83/83 [00:50<00:00,  1.63it/s]


Epoch 9: Train Loss: 1.0680, Train Acc: 58.70%


Epoch 10/20: 100%|██████████| 83/83 [00:47<00:00,  1.73it/s]


Epoch 10: Train Loss: 1.0644, Train Acc: 58.28%


Epoch 11/20: 100%|██████████| 83/83 [01:02<00:00,  1.32it/s]


Epoch 11: Train Loss: 1.0612, Train Acc: 58.15%


Epoch 12/20: 100%|██████████| 83/83 [00:50<00:00,  1.64it/s]


Epoch 12: Train Loss: 1.0622, Train Acc: 58.17%


Epoch 13/20: 100%|██████████| 83/83 [00:42<00:00,  1.93it/s]


Epoch 13: Train Loss: 1.0629, Train Acc: 58.55%


Epoch 14/20: 100%|██████████| 83/83 [00:46<00:00,  1.79it/s]


Epoch 14: Train Loss: 1.0701, Train Acc: 58.09%


Epoch 15/20: 100%|██████████| 83/83 [00:52<00:00,  1.59it/s]


Epoch 15: Train Loss: 1.0697, Train Acc: 58.57%


Epoch 16/20: 100%|██████████| 83/83 [00:50<00:00,  1.64it/s]


Epoch 16: Train Loss: 1.0778, Train Acc: 58.23%


Epoch 17/20:  58%|█████▊    | 48/83 [00:26<00:19,  1.79it/s]


KeyboardInterrupt: 