In [None]:
import sys
sys.path.append("../")

from fgi import *
import torch
from torch import nn, randn, cat
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision
from lightning.pytorch import Trainer

In [None]:
class Digit:
    """Value object holding one digit as an ``int``.

    Args:
        digit: any value accepted by ``int()`` (e.g. ``"7"`` or ``7``).
        *args, **kwargs: accepted for call compatibility and ignored.
    """

    def __init__(self, digit, *args, **kwargs):
        value = int(digit)
        self.digit = value

    def __str__(self):
        # Render just the number, e.g. "7".
        return str(self.digit)

In [None]:
class DigitClassification(NonCodeProblem):
    """Classify a single-channel 28x28 image as one of the labels "0".."9".

    The model is a three-stage fgi pipeline (see ``forward``):
    image representation -> property layer -> option chooser over the digit labels.
    Recognised results are wrapped in ``Digit`` (see ``_as_object``).
    """

    def __init__(self, _id, *args, **kwargs):
        super().__init__(_id, *args, **kwargs)
        phi_dim = 128  # feature width shared by every stage below

        # Stage 1: encode the (1, 28, 28) image into a phi_dim-wide representation.
        image_encoder = ImageRepresent(img_shape=(1, 28, 28), patch_size=7, num_heads=1, phi_dim=phi_dim)
        self._represent_unit = RepresentLayer.from_units([image_encoder], output_dim=phi_dim)

        # Stage 2: two property units over the representation.
        self._property = PropertyLayer.from_units(
            [PropertyUnit(phi_dim=phi_dim) for _ in range(2)]
        )

        # Stage 3: choose one of the ten digit labels for the "digit" property.
        # NOTE(review): meaning of the leading ``2`` is defined by fgi.ChooseOptions
        # (possibly tied to the two property units above) — confirm.
        digit_labels = [str(n) for n in range(10)]
        chooser = ChooseOptions(2, options=digit_labels, property_name="digit", phi_dim=phi_dim)
        self._choose = CoPropertyLayer.from_units([chooser])

        self._update_additional_infor()

    def forward(self, x, *args, **kwargs):
        """Pass ``x`` through represent -> property -> choose and return the chooser output."""
        for stage in (self._represent_unit, self._property, self._choose):
            x = stage(x)
        return x

    def recognize_unknown(self, *args, **kwargs):
        """No-op: handling of unknown inputs is not implemented for this problem."""
        return None

    @property
    def _as_object(self):
        """The class used to wrap a recognised result (``Digit``)."""
        return Digit

In [None]:
class DigitLearner(LightningLearner):
    """Lightning learner for ``DigitClassification``.

    The training objective is cross-entropy between the first element of the
    model output (``y_hat[0]``) and the integer labels ``y``.
    """

    # Adam learning rate used by ``configure_optimizers``; kept equal to the
    # ``lr`` passed to ``learner.compile`` in the setup cell.
    learning_rate = 0.01

    def __init__(self, problem, *args, **kwargs):
        super().__init__(problem, *args, **kwargs)
        # Output name -> loss function. This must be a mapping because
        # ``_aggerate_loss`` looks the loss up by key; a ``("digit", loss)``
        # tuple raises TypeError there.
        self.loss_infor = {"digit": nn.CrossEntropyLoss()}

    def _aggerate_loss(self, y_hat, y, *args, **kwargs):
        """Return the named losses plus ``"total_loss"`` (here just the digit loss).

        Assumes ``y_hat[0]`` holds the (batch, 10) digit logits and ``y`` the
        integer class labels — TODO confirm against ChooseOptions' output.
        """
        digit_loss = self.loss_infor["digit"](y_hat[0], y)
        return {"digit": digit_loss, "total_loss": digit_loss}

    def training_step(self, batch, batch_idx, *args, **kwargs):
        """Delegate to the base learner, logging per step."""
        return super().training_step(batch, batch_idx, *args, on_step=True, **kwargs)

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        """Delegate to the base learner, logging per step."""
        return super().validation_step(batch, batch_idx, *args, on_step=True, **kwargs)

    def configure_optimizers(self):
        """Create the Adam optimizer over all of the learner's parameters."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
# MNIST train/test loaders. Images become tensors normalised with mean 0.1307 /
# std 0.3081 (the MNIST training-set statistics).
MNIST_ROOT = "/files/"  # download/cache directory — absolute path, adjust per machine
BATCH_SIZE = 32
NUM_WORKERS = 7

mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])


def make_mnist_loader(train, shuffle):
    """Build a DataLoader over the MNIST train (``train=True``) or test split.

    Downloads the data into ``MNIST_ROOT`` if it is not already there.
    """
    dataset = torchvision.datasets.MNIST(MNIST_ROOT, train=train, download=True, transform=mnist_transform)
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        # persistent_workers=True is rejected by DataLoader when num_workers == 0.
        persistent_workers=NUM_WORKERS > 0,
    )


train_loader = make_mnist_loader(train=True, shuffle=True)
# Evaluation metrics do not depend on order, so keep the test pass deterministic.
test_loader = make_mnist_loader(train=False, shuffle=False)

In [None]:
# Build the problem and its learner, set every ``learner.learnable`` entry to True,
# then compile with Adam. Fall back to CPU so the notebook still runs without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

solver = DigitClassification(_id="digit_classification")
learner = DigitLearner(solver)
# Snapshot the ids first: the ``learnable`` setter may update the mapping whose
# key view we would otherwise be iterating over.
for _id in list(learner.learnable.keys()):
    learner.learnable = (_id, True)
learner.compile(optim.Adam, device=DEVICE, lr=0.01)

In [None]:
# Train for 2-5 epochs, validating on the held-out MNIST test split
# (validating on the training loader would only measure memorisation).
trainer = Trainer(accelerator="auto", min_epochs=2, max_epochs=5)
trainer.fit(learner, train_dataloaders=train_loader, val_dataloaders=test_loader)

In [None]:
# Export the solver's computation graph to TensorBoard by tracing one dummy
# (1, 1, 28, 28) input. The input is created on the same device as the solver's
# parameters, so tracing works whether the solver ended up on CPU or GPU.
_first_param = next(solver.parameters(), None)
graph_device = _first_param.device if _first_param is not None else "cpu"

# The context manager closes the writer even if tracing raises.
with SummaryWriter(log_dir="experiment/1") as writer:
    writer.add_graph(solver, randn(1, 1, 28, 28, device=graph_device))