In [None]:
!pip install multidict -q #efficientnet_pytorch

In [None]:
import pandas as pd
import numpy as np
import os

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from PIL import Image

import albumentations as A
from albumentations.pytorch import ToTensorV2

# from efficientnet_pytorch import EfficientNet
from torch.utils.data import Dataset, DataLoader, Sampler
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models

from numba import jit

from sklearn.model_selection import train_test_split

from tqdm.auto import tqdm
from multidict import MultiDict

import cv2

In [None]:
DEFAULT_RANDOM_SEED = 42
import random
import numpy as np


def set_all_seeds(seed=DEFAULT_RANDOM_SEED):
    """Seed every random number generator this notebook uses.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), and switches cuDNN to deterministic mode.

    Args:
        seed: integer seed applied to all generators.
    """
    # Python-level generators
    random.seed(seed)
    # Exported for any child process started later.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)

    # PyTorch generators; seed_all covers every GPU because the model may be
    # wrapped in DataParallel when more than one device is present.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_all_seeds(seed=DEFAULT_RANDOM_SEED)

In [None]:
# Maps a class token found in a file name to the integer label of that file.
# Order matters: 'GP' has to be tested before 'G', which is a substring of it.
class_dict = {
    'GP': 0, 'G': 1, 'M': 2, 'T': 3, 'clear': 4 # TODO (original note): drop one class - clear
}

# class_dict = {
#     '-20':0, '-25':1, '-30':2, '-35':3
# }

def get_class_from_path(path, class_dict=class_dict):
    """Return the label of the first ``class_dict`` token contained in the file name.

    Only the last path component is inspected, so a directory whose name
    happens to contain a token cannot decide the label.

    Args:
        path: file name or full path of an image.
        class_dict: mapping ``token -> label``; tokens are tried in insertion order.

    Returns:
        The label of the first matching token, or ``None`` if no token matches.
    """
    file_name = os.path.basename(path)
    for key, label in class_dict.items():
        if key in file_name:
            return label
    return None

In [None]:
root_dir = '/kaggle/input/lozzzz/all_images'
# os.listdir returns entries in arbitrary order; sort them so that the seeded
# train/validation split is the same on every run and every machine.
image_files = sorted(f for f in os.listdir(root_dir) if f.endswith('.jpg'))

# Keep only files whose name contains a known class token; `labels[i]` is the
# label of `new_img_files[i]`.
new_img_files = []
labels = []
for elem in image_files:
    label = get_class_from_path(elem)
    if label is not None:
        labels.append(label)
        new_img_files.append(elem)

In [None]:
# The split is made over image files (not fragments), so all crops of one image
# stay on the same side; stratification keeps class proportions equal.
train_paths, valid_paths = train_test_split(
    new_img_files,
    train_size=0.7,
    shuffle=True,
    stratify=labels,
    random_state=42,
)

In [None]:
# Labels of the training images, taken from the first class token found in each name.
labels_train = []
for file_name in train_paths:
    matched_key = next((token for token in class_dict if token in file_name), None)
    if matched_key is not None:
        labels_train.append(class_dict[matched_key])

In [None]:
from collections import Counter

# Number of training images per class, ordered by label index. Derived from the
# actual split instead of being typed in by hand, so it cannot go stale.
label_counts = Counter(labels_train)
class_counts = [label_counts[label] for label in range(len(class_dict))]
assert all(count > 0 for count in class_counts), "every class needs at least one training image"

# Inverse-frequency weights, normalised to sum to 1 (rarer classes weigh more).
class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
class_weights /= class_weights.sum()

In [None]:
def cut_fragments(image, mode, n, size):
    """Cut square ``size`` x ``size`` fragments out of ``image``.

    Deliberately plain Python (no numba JIT): the work is only slicing, and
    numba-compiled code uses its own random state that ``random.seed`` does not
    reach, which made the 'random' mode impossible to reproduce.

    Args:
        image: array whose first two axes are (height, width).
        mode: 'central' - an ``n`` x ``n`` grid of crops whose top-left corners
            sit at multiples of ``width / n`` and ``height / n``;
            'random' - ``n`` crops at uniformly random positions.
        n: grid side ('central') or number of crops ('random').
        size: side length of every crop in pixels.

    Returns:
        List of array views into ``image``. In 'central' mode the list has
        ``n * n`` items ordered column by column (index ``i * n + j``).

    Raises:
        ValueError: unknown ``mode``, or 'random' mode with an image smaller
            than ``size``.
    """
    height, width = image.shape[:2]

    if mode == 'central':
        # Last top-left corner that still leaves room for a full crop; clamping
        # to it keeps edge fragments the same shape as all the others.
        max_left = max(width - size, 0)
        max_upper = max(height - size, 0)
        fragments = []
        for i in range(n):
            left = min(int((width / n) * i), max_left)
            for j in range(n):
                upper = min(int((height / n) * j), max_upper)
                fragments.append(image[upper:upper + size, left:left + size])
        return fragments

    if mode == 'random':
        if size > width or size > height:
            raise ValueError(
                f"fragment size {size} does not fit into a {height}x{width} image"
            )
        fragments = []
        for _ in range(n):
            left = random.randint(0, width - size)
            upper = random.randint(0, height - size)
            fragments.append(image[upper:upper + size, left:left + size])
        return fragments

    raise ValueError(f"unknown mode {mode!r}; expected 'central' or 'random'")

In [None]:
# Clip the tails: values below the 1st percentile become 0, values above the 99th become 1.
def cut_percentiles(image):
    """Return a copy of ``image`` with its extreme values replaced by 0 and 1.

    The input array is left untouched.

    Args:
        image: numeric array. The replacement values 0 and 1 suggest it is
            expected to be normalised to [0, 1] - TODO confirm with callers.

    Returns:
        New array of the same shape and dtype as ``image``.
    """
    q1 = np.percentile(image, 1)
    q99 = np.percentile(image, 99)

    clipped = np.array(image, copy=True)
    clipped[clipped < q1] = 0
    clipped[clipped > q99] = 1
    return clipped

def bilateral_filter(image):
    """Apply OpenCV's bilateral filter with diameter 9 and both sigmas set to 75."""
    diameter = 9
    sigma_color = sigma_space = 75
    return cv2.bilateralFilter(image, diameter, sigma_color, sigma_space)

# Bilateral filtering for a float image
def apply_bilateral_filter_to_normalized(image):
    """Scale ``image`` by 255, cast to uint8, bilateral-filter it and scale back by 1/255.

    Presumably ``image`` holds values in [0, 1] - verify against the caller.
    """
    as_uint8 = (image * 255).astype(np.uint8)
    filtered = bilateral_filter(as_uint8)
    return filtered / 255.0

def median_filter(image):
    """Apply OpenCV's median blur with a kernel size of 5."""
    kernel_size = 5
    return cv2.medianBlur(image, kernel_size)

def apply_median_filter_to_normalized(image):
    """Scale ``image`` by 255, cast to uint8, median-filter it and scale back by 1/255.

    Presumably ``image`` holds values in [0, 1] - verify against the caller.
    """
    as_uint8 = (image * 255).astype(np.uint8)
    filtered = median_filter(as_uint8)
    return filtered / 255.0

In [None]:
class LozDataset(Dataset):
    """Serves normalised grayscale fragments cut out of full-size images.

    Every image contributes ``n * n`` fragments in 'central' mode and ``n``
    fragments in any other mode; a flat index is mapped to
    (image, fragment within that image).

    Args:
        root_dir: directory that contains the images.
        image_files: file names relative to ``root_dir``.
        mode: fragment layout passed to ``cut_fragments`` ('central' or 'random').
        n: grid side ('central') or fragment count per image (otherwise).
        size: side length of a fragment in pixels.
        transform: optional callable invoked as ``transform(image=fragment)``
            that returns a mapping with the result under the ``'image'`` key.
    """

    def __init__(self, root_dir, image_files, mode: str = 'central', n: int = 3, size: int = 224, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [os.path.join(self.root_dir, image_file) for image_file in image_files]
        self.mode = mode
        self.n = n
        self.size = size

    def _fragments_per_image(self):
        """Number of fragments one image contributes (single source of truth for len and indexing)."""
        if self.mode == 'central':
            return self.n * self.n
        return self.n

    def __len__(self):
        return len(self.image_files) * self._fragments_per_image()

    def __getitem__(self, idx):
        """Return ``(fragment_tensor, image_path, label)`` for the flat index ``idx``."""
        # Must agree with __len__: the previous truthiness test on `self.mode`
        # always chose n*n, which broke indexing in 'random' mode.
        image_idx, fragment_idx = divmod(idx, self._fragments_per_image())

        img_name = self.image_files[image_idx]
        image = Image.open(img_name).convert('L')
        image_np = np.array(image)

        fragments = cut_fragments(image=image_np, mode=self.mode, n=self.n, size=self.size)
        fragment = fragments[fragment_idx]

        # Normalise the fragment by its own maximum. float32 is what the final
        # tensor uses anyway and is a dtype augmentation libraries accept.
        fragment = fragment.astype(np.float32)
        max_value = fragment.max()
        if max_value > 0:  # an all-black fragment would otherwise turn into NaNs
            fragment = fragment / max_value
#         fragment = cut_percentiles(fragment)
#         fragment = apply_bilateral_filter_to_normalized(fragment)

        if self.transform:
            fragment = self.transform(image=fragment)['image']

        label = get_class_from_path(img_name)
        image_t = torch.tensor(fragment, dtype=torch.float32)

        return image_t, img_name, label

In [None]:
# Salt-and-pepper noise
class SaltAndPepper(A.ImageOnlyTransform):
    """Set a small random subset of pixels to 1 ("salt") or 0 ("pepper").

    The literal values 0 and 1 are written, so the input is presumably a float
    image normalised to [0, 1] - verify before using it on uint8 data.

    Args:
        p: probability of applying the transform.
        salt_ratio: share of the noisy pixels that become salt; the rest become pepper.
        amount: fraction of ``image.size`` that is turned into noise.
        always_apply: forwarded to the albumentations base class.
    """

    def __init__(self, p=1., salt_ratio=0.5, amount=0.0008, always_apply=True):
        # Passed by keyword so the call does not depend on the positional
        # parameter order of the base-class constructor.
        super().__init__(always_apply=always_apply, p=p)
        self.salt_ratio = salt_ratio
        self.amount = amount

    def apply(self, image, **params):
        image_copy = np.copy(image)  # never modify the caller's array
        rows, cols = image_copy.shape[:2]
        if rows == 0 or cols == 0:
            return image_copy

        # np.random.randint excludes `high`, so `rows` / `cols` (not `- 1`) is
        # the bound that lets the last row and column receive noise too.
        num_salt = int(np.ceil(self.amount * image.size * self.salt_ratio))
        salt_rows = np.random.randint(0, rows, num_salt)
        salt_cols = np.random.randint(0, cols, num_salt)
        image_copy[salt_rows, salt_cols] = 1

        num_pepper = int(np.ceil(self.amount * image.size * (1.0 - self.salt_ratio)))
        pepper_rows = np.random.randint(0, rows, num_pepper)
        pepper_cols = np.random.randint(0, cols, num_pepper)
        image_copy[pepper_rows, pepper_cols] = 0

        return image_copy

In [None]:
# Augmentations applied to training fragments.
# Disabled experiments, kept for reference:
#     A.RandomBrightnessContrast(p=1, contrast_limit=(.2), brightness_by_max=True, brightness_limit=(.2)),
#     A.GaussianBlur(p=1, blur_limit=(1,3)),
#     A.CoarseDropout(max_holes=6, p=1., fill_value=200, max_height=3, max_width=3),
#     A.CoarseDropout(max_holes=6, p=1., fill_value=0, max_height=3, max_width=3),
#     A.GaussNoise(var_limit=(10.0), p=1),  # white noise
#     A.ElasticTransform(alpha=2, sigma=20, alpha_affine=10, p=.4),
train_transform = A.Compose(
    transforms=[
        A.HorizontalFlip(p=0.3),
        A.Rotate(limit=30, p=0.3),
        SaltAndPepper(salt_ratio=0.4),
    ]
)


In [None]:
root_dir = '/kaggle/input/lozzzz/all_images'
# Only the training set is augmented; both sets use the dataset's default mode and n.
train_dataset = LozDataset(
    root_dir=root_dir,
    image_files=train_paths,
    size=240,
    transform=train_transform,
)
valid_dataset = LozDataset(
    root_dir=root_dir,
    image_files=valid_paths,
    size=240,
)

In [None]:
def seed_worker(worker_id):
    """Seed NumPy and `random` inside a DataLoader worker from that worker's PyTorch seed.

    Without this, worker processes can start from identical NumPy state and
    produce the same "random" augmentation noise.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Dedicated generator so shuffling order and worker seeds follow the notebook seed.
loader_generator = torch.Generator()
loader_generator.manual_seed(DEFAULT_RANDOM_SEED)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    drop_last=True,
    worker_init_fn=seed_worker,
    generator=loader_generator,
)
# Validation keeps the last incomplete batch: dropping it would silently skip
# fragments and leave some images with fewer votes in the per-image evaluation.
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=2,
    drop_last=False,
)

In [None]:
NUM_CLASSES = len(class_dict)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE: the variable is still called `mobilenet_v2_model`, but it now holds an
# EfficientNet-B1. The earlier MobileNetV2 setup is kept below for reference.
# mobilenet_v2_model = models.mobilenet_v2(pretrained=True)
# mobilenet_v2_model.features[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
# mobilenet_v2_model.classifier[1] = nn.Linear(mobilenet_v2_model.last_channel, NUM_CLASSES)
# mobilenet_v2_model = mobilenet_v2_model.to(device)
mobilenet_v2_model = models.efficientnet_b1(pretrained=True)
num_classes = len(class_dict)

# Replace the stem convolution with a single-channel one of the same geometry,
# because the dataset yields grayscale fragments. The new layer starts from a
# fresh initialisation rather than the pretrained weights.
first_conv_layer = mobilenet_v2_model.features[0][0]
mobilenet_v2_model.features[0][0] = nn.Conv2d(
    in_channels=1,
    out_channels=first_conv_layer.out_channels,
    kernel_size=first_conv_layer.kernel_size,
    stride=first_conv_layer.stride,
    padding=first_conv_layer.padding,
    bias=False,
)

# Replace the classification head so that it outputs one logit per class.
mobilenet_v2_model.classifier[1] = nn.Linear(
    in_features=mobilenet_v2_model.classifier[1].in_features,
    out_features=num_classes,
)
mobilenet_v2_model = mobilenet_v2_model.to(device)

# Use every visible GPU when there is more than one.
if torch.cuda.device_count() > 1:
    mobilenet_v2_model = nn.DataParallel(mobilenet_v2_model)

In [None]:
# Cross-entropy weighted by the inverse class frequencies computed earlier.
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
# criterion = FocalLoss(alpha=.8)

# AdamW with its default hyper-parameters.
optimizer = torch.optim.AdamW(params=mobilenet_v2_model.parameters())
# exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.6)

# Step schedule: the learning rate is multiplied by `gamma` at each milestone epoch.
milestones = [12, 15, 26]
gamma = 0.3
exp_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=milestones,
    gamma=gamma,
)

In [None]:
import wandb
wandb.init(
    # set the wandb project where this run will be logged
    project="first_expirement",

    # track hyperparameters and run metadata
    # Values that exist as live objects are read from them, so the logged
    # config cannot drift from what is actually being trained.
    config={
        "architecture": "efficientnet_b1",
        "dataset": "lozz",
        "epochs": 30,  # keep in sync with the range() of the training loop
        "fragments": train_dataset.n ** 2 if train_dataset.mode == 'central' else train_dataset.n,
        "central": train_dataset.mode == 'central',
        "batch_size": train_dataloader.batch_size,
        "classes": NUM_CLASSES,
        "size": train_dataset.size,
        "preprocessing": "normalizing",
        "augmentations": """
                        A.HorizontalFlip(p=.3),
                    #     A.RandomBrightnessContrast(p=1, contrast_limit=(.2), brightness_by_max=True, brightness_limit=(.2)),
                        A.Rotate(limit=30, p=.3),
                    #     A.GaussianBlur(p=1, blur_limit=(1,3)),
                    #     A.CoarseDropout(max_holes=6, p=1., fill_value=200, max_height=3, max_width=3),
                    #     A.CoarseDropout(max_holes=6, p=1., fill_value=0, max_height=3, max_width=3),
                    #     A.GaussNoise(var_limit=(10.0), p=1), #белый шум
                        SaltAndPepper(salt_ratio=.4),
                    #    A.ElasticTransform(alpha=2, sigma=20, alpha_affine=10, p=.4),
                        """,
        "optimizer": 'torch.optim.AdamW(mobilenet_v2_model.parameters())',
        "criterion": 'nn.CrossEntropyLoss(weight=class_weights.to(device))',
        "sheduler": 'torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=gamma)'
    }
)

In [None]:
# Per-epoch history, appended to by train_and_validate.
train_loss = []
train_acc = []
test_loss = []
test_acc = []


def _train_one_epoch():
    """Run one optimisation pass over `train_dataloader`.

    Returns:
        (mean batch loss, accuracy over all samples seen).
    """
    mobilenet_v2_model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    for batch_idx, (data, name, target) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        target = target.type(torch.LongTensor)

        # Batches arrive as (batch, H, W); add the channel axis the conv stem expects.
        data = data.unsqueeze(1)
        data, target = data.to(device).float(), target.to(device)
        optimizer.zero_grad()
        outputs = mobilenet_v2_model(data)

        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        preds = outputs.argmax(dim=1)
        correct += (preds == target).sum().item()
        seen += target.size(0)

        if batch_idx % 100 == 0:
            # Visual sanity check of what the network is being fed.
            random_img = data.cpu().numpy()[np.random.randint(data.size(0))][0]
            plt.imshow(random_img, cmap='gray')
            plt.title("Random Image from Batch")
            plt.axis('off')
            plt.show()

    return loss_sum / len(train_dataloader), correct / seen


def _validate_one_epoch():
    """Evaluate the model on `valid_dataloader` without gradient tracking.

    Returns:
        (mean batch loss, accuracy over all samples, true labels, predicted labels).
    """
    mobilenet_v2_model.eval()
    # Fresh accumulator: validation loss must not include the training loss sum.
    loss_sum = 0.0
    correct = 0
    seen = 0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for data, name, target in valid_dataloader:
            data = data.unsqueeze(1)
            data, target = data.to(device), target.to(device)
            data = data.float()

            outputs = mobilenet_v2_model(data)

            loss_sum += criterion(outputs, target).item()
            preds = outputs.argmax(dim=1)
            # Count samples instead of averaging batch means, so a smaller
            # final batch does not skew the accuracy.
            correct += (preds == target).sum().item()
            seen += target.size(0)

            all_targets.extend(target.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    return loss_sum / len(valid_dataloader), correct / seen, all_targets, all_preds


def train_and_validate(epoch):
    """Train for one epoch, validate, plot the confusion matrix and log to W&B.

    Args:
        epoch: zero-based epoch index (used only for printing).

    Returns:
        The module-level history lists ``(train_loss, train_acc, test_loss, test_acc)``,
        each extended by this epoch's value.
    """
                                                ### train
    print(f'EPOCH: {epoch + 1}')
    epoch_train_loss, epoch_train_acc = _train_one_epoch()
    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)

    print(f"Epoch {epoch+1}, Train Loss: {train_loss[-1]:.3f}, Train Acc: {train_acc[-1]:.3f}, ")
    exp_lr_scheduler.step()

                                                ### validate
    epoch_valid_loss, epoch_valid_acc, all_targets, all_preds = _validate_one_epoch()
    test_loss.append(epoch_valid_loss)
    test_acc.append(epoch_valid_acc)

    print(f"Epoch {epoch+1}, Valid Loss: {test_loss[-1]:.3f}, Valid Acc: {test_acc[-1]:.3f}, ")

    cm = confusion_matrix(all_targets, all_preds)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot()
    plt.show()

    wandb.log({"train_acc": train_acc[-1], "train_loss": train_loss[-1], "valid_acc": test_acc[-1], "valid_loss": test_loss[-1]})
    return train_loss, train_acc, test_loss, test_acc

In [None]:
import copy

NUM_EPOCHS = 30
best_loss = float('inf')
best_state = None  # weights of the epoch with the lowest validation loss
epochs_without_improvement = 0
early_stopping_threshold = 5

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc, test_loss, test_acc = train_and_validate(epoch)

    if test_loss[-1] < best_loss:
        best_loss = test_loss[-1]
        epochs_without_improvement = 0
        # Snapshot the weights; later epochs keep modifying the live tensors.
        best_state = copy.deepcopy(mobilenet_v2_model.state_dict())
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= early_stopping_threshold:
        print(f"Early stopping triggered after {epochs_without_improvement} epochs without improvement.")
        break

# Evaluate with the best weights, not with those of the last (worse) epochs.
if best_state is not None:
    mobilenet_v2_model.load_state_dict(best_state)

In [None]:
# End the W&B run that was opened by the earlier wandb.init call.
wandb.finish()

## Оценка результатов для полных картинок

In [None]:
def get_predictions():
    """Predict a class for every validation fragment, grouped by source image.

    Returns:
        MultiDict in which each image path is mapped to several values - the
        predicted classes of the fragments of that image. An image has as many
        entries as fragments of it were served by `valid_dataloader`.
    """
    fragments = MultiDict()
    # Move (a no-op if already there) once, not on every batch.
    model = mobilenet_v2_model.to(device)
    model.eval()
    with torch.no_grad():
        for data, name, _target in valid_dataloader:
            data = data.unsqueeze(1).to(device).float()
            preds = model(data).argmax(dim=1)

            # One device-to-host transfer per batch instead of one per fragment.
            for image_path, pred in zip(name, preds.tolist()):
                fragments.add(image_path, pred)
    return fragments
fragments = get_predictions()

In [None]:
from collections import Counter

def most_common_class_per_key(multidict):
    """Reduce per-fragment predictions to a single majority-vote class per image.

    Args:
        multidict: MultiDict-like object exposing ``keys()`` and ``getall(key)``,
            holding one predicted class per fragment under the image's key.

    Returns:
        dict mapping every distinct key to its most frequent value.
    """
    return {
        image_key: Counter(multidict.getall(image_key)).most_common(1)[0][0]
        for image_key in set(multidict.keys())
    }

# Majority-vote class for every validation image, keyed by image path.
result = most_common_class_per_key(fragments)

In [None]:
def accuracy_full(result: dict):
    """Compute the accuracy of per-image predictions against labels taken from the paths.

    Args:
        result: dict mapping an image path to the class predicted most often
            for that image.

    Returns:
        Human-readable string with the accuracy in percent, rounded to 3 decimals.

    Raises:
        ValueError: if ``result`` is empty (accuracy is undefined).
    """
    if not result:
        raise ValueError("result is empty: there are no predictions to score")

    correct = sum(
        1 for path, y_pred in result.items() if get_class_from_path(path) == y_pred
    )
    accuracy = round(correct / len(result) * 100, 3)
    return f"Accuracy для полных картинок: {accuracy}%"
# Last expression of the cell: shows the per-image accuracy summary as the cell output.
accuracy_full(result)