In [1]:
import os
from models.models import get_model
from scripts.trainer import Trainer
from scripts.training_args import TrainingArguments
from scripts.dataloader import DataGen, create_dataloaders

# --- Experiment configuration -------------------------------------------------
# Every value a reader may want to tune lives here; later cells read `args`
# (and `device` directly).
data_path = './data/azh_wound_care_center_dataset_patches/'  # relative to the notebook's working directory
model_name = "unet"
num_train_epochs = 10
batch_size = 8
learning_rate = 0.001
logdir = "logs"
checkpoint_path = "checkpoints"
device = "cuda"  # NOTE(review): assumes a CUDA device is available on the kernel host

expt_name = "temp"
expt_description = ""  # NOTE(review): defined but not forwarded into `args` below

# Single source of truth handed to get_model(...) and TrainingArguments(**args).
# Key order matches the original literal.
args = dict(
    data=data_path,
    model=model_name,
    num_train_epochs=num_train_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    logdir=logdir,
    checkpoint_path=checkpoint_path,
    device=device,
    expt_name=expt_name,
)

In [2]:
# Build the model named in the configuration cell.
model = get_model(args["model"], args, device=device)

# DataGen gets the configured (relative) data path joined onto the notebook's
# current working directory.
dataset_root = os.path.join(os.getcwd(), args["data"])
# NOTE(review): which split `split_ratio=0.2` controls is defined inside DataGen — confirm.
data_gen = DataGen(dataset_root, split_ratio=0.2)
train_loader, validation_loader, test_loader = create_dataloaders(
    data_gen, args["batch_size"], args["device"]
)

# Wire the model and the three loaders into the project's Trainer.
training_args = TrainingArguments(**args)
trainer = Trainer(
    model,
    train_loader=train_loader,
    validation_loader=validation_loader,
    test_loader=test_loader,
    args=training_args,
)

# Fit, then evaluate. The last expression's value is the cell's displayed output.
trainer.train()
trainer.evaluate()

logging path:  logs/temp/20240309_124754
checkpoint path:  checkpoints/temp/20240309_124754


Epoch 2/10:  10%|[32m█         [0m| 1/10 [02:54<26:10, 174.47s/it]


Epoch 1/10 Train loss: 72.0734 Validation loss: 11.6424


Epoch 3/10:  20%|[32m██        [0m| 2/10 [05:45<23:00, 172.51s/it]


Epoch 2/10 Train loss: 37.3199 Validation loss: 11.7214


Epoch 4/10:  30%|[32m███       [0m| 3/10 [08:36<20:02, 171.75s/it]


Epoch 3/10 Train loss: 37.0169 Validation loss: 9.9381


Epoch 5/10:  40%|[32m████      [0m| 4/10 [11:27<17:08, 171.43s/it]


Epoch 4/10 Train loss: 43.8004 Validation loss: 10.4916


Epoch 6/10:  50%|[32m█████     [0m| 5/10 [14:18<14:17, 171.42s/it]


Epoch 5/10 Train loss: 35.5968 Validation loss: 10.1043


Epoch 7/10:  60%|[32m██████    [0m| 6/10 [17:10<11:25, 171.35s/it]


Epoch 6/10 Train loss: 34.6669 Validation loss: 9.0828


Epoch 8/10:  70%|[32m███████   [0m| 7/10 [20:02<08:35, 171.86s/it]


Epoch 7/10 Train loss: 30.2600 Validation loss: 8.6848


Epoch 8/10:  70%|[32m███████   [0m| 7/10 [20:45<08:53, 177.92s/it]


KeyboardInterrupt: 

In [1]:
import numpy as np

In [10]:
# Sanity check for batching: stacking N arrays of shape (1, H, W) and removing
# the per-array leading singleton axis should yield an (N, H, W) batch.
rng = np.random.default_rng(0)  # seeded so Restart & Run All reproduces the cell
arrays = [rng.random((1, 4, 4)) for _ in range(10)]
# np.stack adds a new leading axis -> (N, 1, H, W). Squeeze ONLY axis 1: a bare
# .squeeze() would also drop the batch axis when N == 1 (or H/W if either is 1).
stacked = np.stack(arrays).squeeze(axis=1)
stacked.shape

(10, 4, 4)