In [1]:
import sys
sys.path.append("../")

from fgi import *
from torch import jit
from torch import nn, randn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision
from lightning.pytorch import Trainer
from lightning import LightningModule

In [2]:
# Probe whether PyTorch can see a CUDA device in this kernel;
# the displayed value is True/False.
from torch.cuda import is_available

cuda_visible = is_available()
cuda_visible

False

In [3]:
class Digit:
    """Small value object wrapping a single digit, stored as an int."""

    def __init__(self, digit ,*args, **kwargs):
        # `digit` is coerced with int(), so string labels such as "7" work.
        # Any extra positional/keyword fields are accepted and ignored.
        self._digit = int(digit)

    def __repr__(self):
        return "Digit({})".format(self._digit)

    def __str__(self):
        # Human-readable form is identical to the repr.
        return self.__repr__()

In [4]:
class DigitProblem(NonCodeProblem):
    """
    Digit classifier.

    Input images have shape (1, 28, 28). The output property "digit" is chosen
    among the string options "0".."9". Data flows through four layers in order:
    represent -> co-represent -> property -> co-property.
    """
    def __init__(self, _id, *args, **kwargs):
        super().__init__(_id, *args, **kwargs)
        width = 128  # shared phi_dim / output_dim used by every layer below

        # Two identically configured image units (built in the same order as before).
        # An EdgeRepresent unit (img_shape=(1, 28, 28), patch_size=4, num_heads=2,
        # phi_dim=128) was tried here and is currently disabled.
        image_unit_cfg = dict(img_shape=(1, 28, 28), patch_size=4, num_heads=1, phi_dim=width)
        self._represent = RepresentLayer.from_units(
            [ImageRepresent(**image_unit_cfg) for _ in range(2)],
            output_dim=width,
        )
        self._co_represent = CoRepresentLayer.from_units(
            [CoRepresentUnit(2, phi_dim=width) for _ in range(2)]
        )
        self._property = PropertyLayer.from_units(
            [PropertyUnit(phi_dim=width) for _ in range(4)]
        )
        digit_labels = [str(value) for value in range(10)]  # "0" .. "9"
        self._co_property = CoPropertyLayer.from_units(
            [ChooseOptions(4, options=digit_labels, property_name="digit", phi_dim=width)]
        )

        self._update_additional_infor()

    def forward(self, x, *args, **kwargs):
        """Pass `x` through the four layers in pipeline order and return the last output."""
        stages = (self._represent, self._co_represent, self._property, self._co_property)
        for stage in stages:
            x = stage(x)
        return x

    def recognize_unknown(self, *args, **kwargs):
        """Not implemented for this problem; returns None."""
        pass

    @property
    def _as_object(self):
        """The class used to materialise interpreted predictions."""
        return Digit

    def as_instance(self, x, skip_inference : bool = False, *args, **kwargs):
        """
        Convert a model input (or, with skip_inference=True, an already computed
        model output) into a Digit.

        Extra keyword arguments override / extend the fields produced by
        `self._co_property.intepret` before they are passed to the Digit constructor.
        """
        model_output = x if skip_inference else self.forward(x)
        fields = self._co_property.intepret(model_output)
        fields.update(**kwargs)
        return self._as_object(**fields)


In [5]:
# Identifier passed to the digit problem instance.
PROBLEM_ID = "digit_problem"
solver = DigitProblem(_id=PROBLEM_ID)

In [6]:
class DigitLearner(LightningLearner):
    """
    Lightning learner for DigitProblem.

    Registers a cross-entropy loss for the "digit" property and logs
    training/validation metrics per step.
    """
    def __init__(self, problem, *args, **kwargs):
        super().__init__(problem, *args, **kwargs)
        # NOTE(review): this goes through the base class's `loss_infor` attribute/setter;
        # presumably it populates `self._loss_infor`, which `_aggerate_loss` reads — confirm
        # in LightningLearner.
        self.loss_infor = ("digit", nn.CrossEntropyLoss())

    def _aggerate_loss(self, y_hat, y, *args, **kwargs):
        """
        Compute and sum the per-property losses.

        Args:
            y_hat: mapping from property name to the model output for that property.
            y: target passed unchanged to every property's loss function.

        Returns:
            dict with a "loss_<property>" entry per property (in y_hat's iteration
            order) followed by "total_loss", their sum.
        """
        total_loss = 0.
        losses = {}
        # `prop_name` rather than `property` so the builtin is not shadowed.
        for prop_name in y_hat.keys():
            prop_loss = self._loss_infor[prop_name](y_hat[prop_name], y)
            total_loss += prop_loss
            losses[f"loss_{prop_name}"] = prop_loss
        losses["total_loss"] = total_loss
        return losses

    def training_step(self, batch, batch_idx, *args, **kwargs):
        # Default to per-step logging. setdefault lets a caller-supplied on_step win
        # instead of raising "got multiple values for keyword argument 'on_step'".
        kwargs.setdefault("on_step", True)
        return super().training_step(batch, batch_idx, *args, **kwargs)

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        # Same on_step handling as training_step.
        kwargs.setdefault("on_step", True)
        return super().validation_step(batch, batch_idx, *args, **kwargs)

In [7]:
from torch.cuda import is_available

learner = DigitLearner(solver)

# Mark every learnable unit as trainable. The ids are snapshotted first because each
# assignment goes through the `learnable` setter while we would otherwise still be
# iterating the live keys view of the same mapping.
for unit_id in list(learner.learnable.keys()):
    learner.learnable = (unit_id, True)

# Cell 2 shows CUDA may be unavailable in this kernel, so fall back to CPU instead of
# hard-coding "cuda:0". NOTE(review): assumes learner.compile accepts "cpu" as a device.
DEVICE = "cuda:0" if is_available() else "cpu"
learner.compile(optim.SGD, device=DEVICE, lr=0.001)

# Total parameters / trainable parameters of the problem model.
print("Tổng số tham số: ", learner.total_parameters(solver))
print("Tổng số tham số huấn luyện: ", learner.total_learnable_parameters(solver))

Tổng số tham số:  334236
Tổng số tham số huấn luyện:  334236


In [8]:
# MNIST data loaders. Both splits share one preprocessing pipeline:
# convert to a tensor, then normalize with the mean/std constants below.
MNIST_ROOT = "/files/"  # NOTE(review): absolute path; consider a configurable, project-relative data dir
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
BATCH_SIZE = 32
NUM_WORKERS = 7

mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(MNIST_MEAN, MNIST_STD),
])


def make_mnist_loader(train: bool, shuffle: bool) -> DataLoader:
    """Return a DataLoader over the MNIST train (train=True) or test split.

    The dataset is downloaded into MNIST_ROOT if it is not already there.
    """
    dataset = torchvision.datasets.MNIST(
        MNIST_ROOT, train=train, download=True, transform=mnist_transform
    )
    # Uses the explicitly imported DataLoader; the bare name `torch` is never imported here.
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        persistent_workers=True,
    )


train_loader = make_mnist_loader(train=True, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps runs comparable.
test_loader = make_mnist_loader(train=False, shuffle=False)

In [9]:
# Lightning training entry point (currently disabled).
# NOTE(review): the previous version passed train_loader as both the training and the
# validation loader; validation should presumably use the held-out test_loader.
#trainer = Trainer(accelerator="auto", min_epochs=3, max_epochs=10)
#trainer.fit(learner, train_loader, test_loader)

In [10]:
# Trace the model once on a dummy (1, 1, 28, 28) input and log its graph to TensorBoard.
# The context manager closes (and flushes) the writer even if tracing raises.
with SummaryWriter("runs/experiment") as writer:
    writer.add_graph(solver, randn((1, 1, 28, 28)), verbose=True)

graph(%self.1 : __torch__.DigitProblem,
      %x.1 : Float(1, 1, 28, 28, strides=[784, 784, 28, 1], requires_grad=0, device=cpu)):
  %_co_property : __torch__.fgi.layer.CoPropertyLayer.CoPropertyLayer = prim::GetAttr[name="_co_property"](%self.1)
  %_property : __torch__.fgi.layer.PropertyLayer.PropertyLayer = prim::GetAttr[name="_property"](%self.1)
  %_co_represent : __torch__.fgi.layer.CoRepresentLayer.CoRepresentLayer = prim::GetAttr[name="_co_represent"](%self.1)
  %_represent : __torch__.fgi.layer.RepresentLayer.RepresentLayer = prim::GetAttr[name="_represent"](%self.1)
  %1245 : int = prim::Constant[value=-1](), scope: __module._represent/__module._represent._units.0 # e:\simulations\implementations\notebooks\..\fgi\problem\vision\represent.py:55:0
  %1246 : int = prim::Constant[value=4](), scope: __module._represent/__module._represent._units.0/__module._represent._units.0._patch_embedding/__module._represent._units.0._patch_embedding.0 # e:\simulations\implementations\env\Lib\