In [1]:
import functools
from dataclasses import dataclass, asdict
from time import sleep

import mlflow
import mlflow.pytorch
import numpy as np
import torch
import torch.nn as nn

In [2]:
# Show which mlflow installation this kernel actually imported.
print(f"mlflow module path: {mlflow.__file__}")

mlflow module path: d:\Python\Python312\Lib\site-packages\mlflow\__init__.py


In [3]:
@dataclass
class Config:
    """Hyper-parameters for the training run.

    The whole object is logged to MLflow via ``asdict(CONFIG)``, so every
    field here shows up as a run parameter.
    """

    # Number of full passes over the training data.
    EPOCHS: int = 1000
    # Step size handed to the SGD optimizer (a fraction, hence float).
    LEARNING_RATE: float = 0.1


# Single shared configuration instance used by the cells below.
CONFIG = Config()

In [4]:
# Select the MLflow experiment that runs started in this notebook are logged under.
mlflow.set_experiment(experiment_name='Test Experiment')
def mlflow_run_decorator(run_name=None):
    """Build a decorator that executes the wrapped function inside an MLflow run.

    For every call of the decorated function a new run is started, the
    function is executed, and the run is always closed again:

    * success: tag ``Status=SUCCESS``, run status ``FINISHED``
    * ``Exception``: the exception is logged as the ``Exception`` param,
      tag ``Status=FAIL``, run status ``FAILED``, and the exception is re-raised
    * ``KeyboardInterrupt``: tag ``Status=FAIL``, run status ``KILLED``,
      and the interrupt is re-raised

    Args:
        run_name: Optional display name passed to ``mlflow.start_run``.

    Returns:
        A decorator; the decorated function returns whatever the original
        function returns.
    """
    def decorator(func):
        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            mlflow.start_run(run_name=run_name)
            # Assume failure until the function has actually returned, so any
            # exit path we did not anticipate is not recorded as FINISHED.
            run_status = "FAILED"
            try:
                result = func(*args, **kwargs)
                mlflow.set_tag("Status", "SUCCESS")
                run_status = "FINISHED"
                return result
            except KeyboardInterrupt:
                # Not an Exception subclass, so it needs its own handler to
                # avoid leaving an interrupted run without a status tag.
                mlflow.set_tag("Status", "FAIL")
                run_status = "KILLED"
                raise
            except Exception as e:
                mlflow.log_param("Exception", e)
                mlflow.set_tag("Status", "FAIL")
                raise  # bare raise preserves the original traceback
            finally:
                mlflow.end_run(status=run_status)
        return wrapper
    return decorator

In [5]:
# Synthetic regression data: y = 1 + 2*x + Gaussian noise (std 0.1), 100 samples.
# A seeded generator makes the data identical on every Restart & Run All.
DATA_SEED = 42
rng = np.random.default_rng(DATA_SEED)

X = rng.random((100, 1))
y = 1 + 2 * X + .1 * rng.standard_normal((100, 1))

In [6]:
class LinearRegression(nn.Module):
    """Linear model ``y = x @ W.T + b`` implemented as a single ``nn.Linear`` layer.

    Args:
        in_features: Size of each input sample (defaults to 1).
        out_features: Size of each output sample (defaults to 1).
    """

    def __init__(self, in_features=1, out_features=1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)

In [7]:
# Seed torch before building the model so the randomly initialised nn.Linear
# weights (and therefore the logged loss curve) are the same on every fresh run.
torch.manual_seed(42)

model = LinearRegression()
criterion = nn.MSELoss()  # mean squared error between predictions and labels
optimizer = torch.optim.SGD(model.parameters(), lr=CONFIG.LEARNING_RATE)

In [8]:
@mlflow_run_decorator(run_name='Check Sync')
def train_model(epoch_delay=1):
    """Train the global ``model`` on ``X``/``y`` and log progress to MLflow.

    Logs every field of ``CONFIG`` as run parameters, then runs
    ``CONFIG.EPOCHS`` full-batch gradient steps with the global ``criterion``
    and ``optimizer``. The loss is logged as the ``loss`` metric on every
    epoch and printed every 100 epochs.

    Args:
        epoch_delay: Seconds to pause after each epoch (default 1, matching
            the previous hard-coded behaviour). Pass 0 to train at full speed.
    """
    mlflow.log_params(asdict(CONFIG))

    # X and y do not change during training, so convert them to float32
    # tensors once instead of on every epoch.
    inputs = torch.from_numpy(X.astype(np.float32))
    labels = torch.from_numpy(y.astype(np.float32))

    for epoch in range(CONFIG.EPOCHS):
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both logging and printing.
        loss_value = loss.item()
        mlflow.log_metric('loss', loss_value, step=epoch)

        if epoch % 100 == 0:
            print(f'Epoch {epoch}/{CONFIG.EPOCHS}, Loss: {loss_value}')

        # NOTE(review): the per-epoch pause presumably exists so that metric
        # syncing can be observed live (the run is named 'Check Sync') — confirm.
        if epoch_delay > 0:
            sleep(epoch_delay)

In [9]:
# Start training; the decorator on train_model wraps this call in an MLflow run.
# With the defaults visible in this notebook (1000 epochs, 1 s pause per epoch)
# the call blocks for more than 15 minutes unless it is interrupted.
train_model()

Epoch 0/1000, Loss: 3.7179665565490723


KeyboardInterrupt: 