In [1]:
from dataclasses import dataclass
from datetime import datetime

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import numpy as np 
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from tqdm.auto import tqdm
from torchfuzzy import FuzzyLayer, DefuzzyLinearLayer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hyperparameters and run-wide objects shared by the cells below.
batch_size = 256  # samples per DataLoader batch
learning_rate = 1e-4  # learning rate passed to AdamW
weight_decay = 1e-2  # weight decay passed to AdamW
labels_count = 10  # width of the one-hot target and of the classifier output
num_epochs = 5  # number of train/test passes in the main loop
fuzzy_dim = 2  # size of the encoder's latent vector fed to the fuzzy layer
fuzzy_rules_count = 10  # second argument of FuzzyLayer.from_dimensions (presumably the rule count)

# TensorBoard writer; the timestamp gives every run its own directory under runs/mnist/.
prefix = "mamdani_mnist"
writer = SummaryWriter(f'runs/mnist/{prefix}_{datetime.now().strftime("%Y%m%d-%H%M%S")}')
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Scratch check (disabled): run DefuzzyLinearLayer on a random batch.
# inp = torch.rand((10,3))
# df = DefuzzyLinearLayer.from_dimensions(3, 2)
# df.forward(inp)

In [4]:
# Per-image preprocessing: convert to a tensor, view it as (-1, 28, 28) and
# subtract 0.5 from every element.
# NOTE(review): ToTensor is documented to scale pixels into [0, 1], so the shift
# presumably centers values on [-0.5, 0.5] — confirm against the torchvision docs.
transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Lambda(lambda x: x.view(-1, 28, 28) - 0.5),
])

In [5]:
def get_target(target_label):
    """
    Build the one-hot target vector for a class label.

    Args:
        target_label (int): Class label.

    Returns:
        torch.Tensor: One-hot tensor of shape (1, labels_count), placed on
        the module-level `device`.
    """
    label_index = torch.LongTensor([target_label])
    one_hot = F.one_hot(label_index, num_classes=labels_count)
    return one_hot.to(device)

In [6]:
# Load the training split (downloaded on first use)
train_data = datasets.MNIST(
    root='~/.pytorch/MNIST_data/',
    train=True,
    transform=transform,
    target_transform=transforms.Lambda(lambda label: get_target(label)),
    download=True,
)

In [7]:
# Load the test split (downloaded on first use)
test_data = datasets.MNIST(
    root='~/.pytorch/MNIST_data/',
    train=False,
    transform=transform,
    target_transform=transforms.Lambda(lambda label: get_target(label)),
    download=True,
)
len(test_data)

10000

In [8]:
# Build the batch iterators; only the training set is shuffled
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [9]:
class Encoder(nn.Module):
    """
    Convolutional encoder that maps an image batch to a latent vector.

    Args:
        fuzzy_dim (int): Size of the latent vector.
    """

    def __init__(self, fuzzy_dim):
        super().__init__()

        # Four unpadded 5x5 convolutions; only the last one is batch-normalised.
        conv_stack = [
            nn.Conv2d(1, 8, kernel_size=5),
            nn.SiLU(),
            nn.Conv2d(8, 16, kernel_size=5),
            nn.SiLU(),
            nn.Conv2d(16, 32, kernel_size=5),
            nn.SiLU(),
            nn.Conv2d(32, 64, kernel_size=5),
            nn.BatchNorm2d(64),
            nn.SiLU(),
        ]
        # Fully connected head. The 9216 input features are hard-coded and
        # equal 64 channels x 12 x 12, i.e. what the stack yields for 28x28 inputs.
        dense_head = [
            nn.Flatten(),
            nn.Linear(9216, 625),
            nn.BatchNorm1d(625),
            nn.SiLU(),
            nn.Linear(625, fuzzy_dim),
        ]

        # Layers are registered in the same order as before, so state_dict
        # keys (encoder.0 ... encoder.13) are unchanged.
        self.encoder = nn.Sequential(*conv_stack, *dense_head)

    def forward(self, x):
        """
        Encode a batch of inputs.

        Args:
            x (torch.Tensor): Input batch; the first convolution expects 1 channel.

        Returns:
            torch.Tensor: Latent vectors of shape (batch, fuzzy_dim).
        """
        return self.encoder(x)

In [10]:
class MamdaniFIS(nn.Module):
    """
    Classifier that chains the convolutional Encoder with a fuzzy head.

    Args:
        fuzzy_dim (int): Size of the latent vector produced by the encoder.
        fuzzy_rules_count (int): Output size requested from FuzzyLayer and
            input size of DefuzzyLinearLayer.
        labels_count (int): Number of classifier outputs.
    """
    def __init__(self, fuzzy_dim, fuzzy_rules_count, labels_count):
        super().__init__()

        self.encoder = Encoder(fuzzy_dim)

        # Built in the same order as before: fuzzy layer first, then the defuzzifier.
        fuzzification = FuzzyLayer.from_dimensions(fuzzy_dim, fuzzy_rules_count, trainable=True)
        defuzzification = DefuzzyLinearLayer.from_dimensions(fuzzy_rules_count, labels_count)
        self.fuzzy = nn.Sequential(fuzzification, defuzzification)

    def forward(self, x):
        """
        Encode the input and pass the latent vector through the fuzzy head.

        Args:
            x (torch.Tensor): Input batch.

        Returns:
            Output of the fuzzy head for each sample (one score per label).
        """
        latent = self.encoder(x)
        return self.fuzzy(latent)

In [11]:
def compute_loss(target_labels, predicted_labels):
    """
    Cross-entropy between predicted scores and one-hot (class-probability) targets.

    Args:
        target_labels (torch.Tensor): Targets of shape (batch, 1, classes);
            the singleton dimension 1 is squeezed away. Values are cast to
            float and moved to the device of `predicted_labels`.
        predicted_labels (torch.Tensor): Unnormalised scores of shape
            (batch, classes).

    Returns:
        torch.Tensor: Scalar loss averaged over the batch.
    """
    # Align dtype and device with the logits so a CPU target never meets GPU
    # predictions; this is a no-op when they already match.
    target_probs = torch.squeeze(target_labels, 1).float().to(predicted_labels.device)

    # Functional form: no module is built per call and .forward() is not
    # invoked directly.
    return F.cross_entropy(predicted_labels, target_probs)
    

In [12]:
# Instantiate the classifier and move its parameters to the selected device.
model = MamdaniFIS(fuzzy_dim, fuzzy_rules_count, labels_count).to(device)

# Count only the parameters that require gradients.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of parameters: {num_params:,}')

# Last expression of the cell: display the module tree.
model

Number of parameters: 5,830,935


MamdaniFIS(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))
      (1): SiLU()
      (2): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1))
      (3): SiLU()
      (4): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
      (5): SiLU()
      (6): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
      (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): SiLU()
      (9): Flatten(start_dim=1, end_dim=-1)
      (10): Linear(in_features=9216, out_features=625, bias=True)
      (11): BatchNorm1d(625, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (12): SiLU()
      (13): Linear(in_features=625, out_features=2, bias=True)
    )
  )
  (fuzzy): Sequential(
    (0): FuzzyLayer()
    (1): DefuzzyLinearLayer()
  )
)

In [13]:
# AdamW over all model parameters, using learning_rate and weight_decay from the config cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [14]:
def train(model, dataloader, optimizer, prev_updates, writer=None):
    """
    Run one training epoch.

    Relies on the module-level `device`, `batch_size` and `compute_loss`.

    Args:
        model (nn.Module): Network to optimise; switched to train mode.
        dataloader: Iterable of (data, target) batches that supports len().
        optimizer (torch.optim.Optimizer): Optimiser bound to the model parameters.
        prev_updates (int): Number of optimisation steps taken before this epoch.
        writer (SummaryWriter, optional): Receives 'Loss/Train' and
            'GradNorm/Train' every 100 steps.

    Returns:
        int: Updated step counter, prev_updates + len(dataloader).
    """
    model.train()

    for batch_idx, (data, target) in enumerate(tqdm(dataloader)):
        n_upd = prev_updates + batch_idx

        data = data.to(device)
        # No-op when the dataset already yields targets on `device`.
        target = target.to(device)

        optimizer.zero_grad()

        # Call the module instead of .forward() so registered hooks still run.
        labels = model(data)

        loss = compute_loss(target, labels)

        loss.backward()

        # clip_grad_norm_ returns the total L2 gradient norm measured before
        # clipping, so it does not have to be recomputed by hand for logging.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        if n_upd % 100 == 0:
            # .item() forces a device sync, so only do it on logging steps.
            total_norm = grad_norm.item()

            print(f'Step {n_upd:,} (N samples: {n_upd*batch_size:,}), Loss: {loss.item():.4f} Grad: {total_norm:.4f}')

            if writer is not None:
                global_step = n_upd
                writer.add_scalar('Loss/Train', loss.item(), global_step)
                writer.add_scalar('GradNorm/Train', total_norm, global_step)

        optimizer.step()

    return prev_updates + len(dataloader)

In [15]:
def test(model, dataloader, cur_step, writer=None):
    """
    Evaluate the model and report the mean loss and the accuracy.

    Relies on the module-level `device` and `compute_loss`.

    Args:
        model (nn.Module): Network to evaluate; switched to eval mode.
        dataloader: Iterable of (data, target) batches; targets are one-hot,
            shaped (batch, 1, classes) or (batch, classes).
        cur_step (int): Global step used as the x-axis for TensorBoard scalars.
        writer (SummaryWriter, optional): Receives 'Loss/Test' and
            'Fuzzy/Test/Accuracy'.

    Returns:
        None. Results are printed and optionally written to `writer`.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for data, target in tqdm(dataloader, desc='Testing'):
            data = data.to(device)
            # No-op when the dataset already yields targets on `device`.
            target = target.to(device)
            batch_n = data.shape[0]

            # Call the module instead of .forward() so registered hooks still run.
            labels = model(data)

            loss = compute_loss(target, labels)

            # Weight by batch size so a smaller last batch does not skew the mean.
            test_loss += loss.item() * batch_n

            pred_target = labels.argmax(dim=1)
            # The class index lives on the LAST axis of the one-hot target.
            # Reducing axis 1 of a (batch, 1, classes) tensor would collapse
            # the singleton axis and yield all zeros.
            target_labels = target.argmax(dim=-1).reshape(-1)
            correct += (pred_target == target_labels).sum().item()
            seen += batch_n

    # Normalise by the number of samples (not batches) so accuracy is a
    # fraction in [0, 1]; max() avoids ZeroDivisionError on an empty loader.
    denom = max(seen, 1)
    test_loss /= denom
    test_accuracy = correct / denom

    print(f'====> Test set loss: {test_loss:.4f} (Accuracy {test_accuracy:.4f})')

    if writer is not None:
        writer.add_scalar('Loss/Test', test_loss, global_step=cur_step)
        writer.add_scalar('Fuzzy/Test/Accuracy', test_accuracy, global_step=cur_step)

In [16]:
# Main loop: one training epoch followed by one evaluation pass.
# prev_updates carries the global step across epochs so that all TensorBoard
# scalars share one x-axis.
# NOTE(review): the saved output of this cell reads 'Epoch 1/50' although
# num_epochs is 5 in the config cell — the output is presumably stale; re-run
# the notebook from a fresh kernel before sharing.
prev_updates = 0
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    prev_updates = train(model, train_loader, optimizer, prev_updates, writer=writer)
    test(model, test_loader, prev_updates, writer=writer)

Epoch 1/50


  0%|          | 1/235 [00:01<04:26,  1.14s/it]

Step 0 (N samples: 0), Loss: 2.3078 Grad: 0.8045


 31%|███▏      | 74/235 [00:34<00:09, 17.76it/s]