In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from trnbl.training_manager import TrainingManager
from trnbl.loggers.local import LocalLogger

In [2]:
class Model(nn.Module):
	"""Minimal 1-input / 1-output linear regressor used to exercise the training manager."""

	def __init__(self) -> None:
		super().__init__()
		# single affine layer; kept under the name `fc` because the eval lambdas read `model.fc.weight`
		self.fc: nn.Linear = nn.Linear(in_features=1, out_features=1)

	def forward(self, x: torch.Tensor) -> torch.Tensor:
		"""Apply the linear layer to ``x``; the last dimension of ``x`` must have size 1."""
		prediction = self.fc(x)
		return prediction

class MockedDataset(torch.utils.data.Dataset):
	"""Synthetic dataset of random ``(input, target)`` pairs for smoke-testing training loops.

	Holds a ``(length, channels, 1)`` tensor of standard-normal samples. For each sample,
	channel 0 is returned as the input and channel 1 as the target; any further channels
	are generated but never read.
	"""

	def __init__(
			self,
			length: int,
			channels: int = 2,
		) -> None:
		"""
		Args:
			length: number of samples.
			channels: number of random channels per sample. Must be at least 2, because
				``__getitem__`` reads channels 0 (input) and 1 (target).

		Raises:
			ValueError: if ``channels < 2``.
		"""
		# fail fast here instead of with an IndexError later, inside DataLoader iteration
		if channels < 2:
			raise ValueError(f"channels must be >= 2 (input + target), got {channels}")
		self.dataset = torch.randn(length, channels, 1)

	def __getitem__(self, idx: int):
		"""Return ``(input, target)`` for sample ``idx``; each is a tensor of shape ``(1,)``."""
		sample = self.dataset[idx]
		return sample[0], sample[1]

	def __len__(self) -> int:
		"""Number of samples (the ``length`` given at construction)."""
		return len(self.dataset)


In [5]:
# run configuration -- tune here rather than editing literals below
SEED: int = 42
N_EPOCHS: int = 10
N_SAMPLES: int = 100
BATCH_SIZE: int = 10
LEARNING_RATE: float = 0.1

# seed torch's global RNG before the model init and the mocked data draw random numbers,
# so Restart & Run All reproduces the same run
torch.manual_seed(SEED)

model = Model()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

# local logger for this integration-test run; train_config describes the setup being trained
logger = LocalLogger(
	project="integration-tests",
	metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
	train_config=dict(
		model=str(model),
		dataset="dummy",
		optimizer=str(optimizer),
		criterion=str(criterion),
	),
	base_path="../tests/_temp",
)

train_loader: DataLoader = DataLoader(MockedDataset(N_SAMPLES), batch_size=BATCH_SIZE)

with TrainingManager(
	model=model,
	logger=logger,
	# (interval string, callable(model) -> metrics) pairs
	# NOTE(review): passed as dict items rather than a dict -- confirm this is the form TrainingManager expects
	evals={
		"1 epochs": lambda model: {'wgt_mean': torch.mean(model.fc.weight).item()},
		"1/2 epochs": lambda model: logger.get_mem_usage(),
	}.items(),
	# NOTE(review): 50 epochs is longer than the N_EPOCHS=10 run, so no periodic checkpoint
	# fires during training -- confirm this is intended
	checkpoint_interval="50 epochs",
) as tr:

	# Training loop
	for _epoch in tr.epoch_loop(range(N_EPOCHS), use_tqdm=True):
		for inputs, targets in tr.batch_loop(train_loader, use_tqdm=True):
			optimizer.zero_grad()
			outputs = model(inputs)
			loss = criterion(outputs, targets)
			loss.backward()
			optimizer.step()

			# outputs and targets are both (batch, 1) regression values: argmax over dim=1 would
			# always be 0, and comparing a (batch,) tensor with (batch, 1) broadcasts to
			# (batch, batch). Log element-wise sign agreement as a proxy "accuracy" instead.
			accuracy = (torch.sign(outputs) == torch.sign(targets)).float().mean().item()

			tr.batch_update(
				samples=len(targets),
				**{"train/loss": loss.item(), "train/acc": accuracy},
			)

# starting logger with id main-h40181-240723_0034-radoza
# starting training manager initialization


  0%|          | 0/10 [00:00<?, ? epochs/s]

# initialized training manager


epoch 1/10: 100%|██████████| 10/10 [00:00<00:00, 28.81 batches/s]
epoch 2/10: 100%|██████████| 10/10 [00:00<00:00, 27.72 batches/s]
epoch 3/10: 100%|██████████| 10/10 [00:00<00:00, 29.44 batches/s]
epoch 4/10: 100%|██████████| 10/10 [00:00<00:00, 30.79 batches/s]
epoch 5/10: 100%|██████████| 10/10 [00:00<00:00, 30.84 batches/s]
epoch 6/10: 100%|██████████| 10/10 [00:00<00:00, 30.53 batches/s]
epoch 7/10: 100%|██████████| 10/10 [00:00<00:00, 29.40 batches/s]
epoch 8/10: 100%|██████████| 10/10 [00:00<00:00, 29.70 batches/s]
epoch 9/10: 100%|██████████| 10/10 [00:00<00:00, 29.79 batches/s]
epoch 10/10: 100%|██████████| 10/10 [00:00<00:00, 29.57 batches/s]
100%|██████████| 10/10 [00:03<00:00,  2.93 epochs/s]

# training complete
# closing logger



