In [2]:
import os 
import pickle
import random
import shutil
import zipfile
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision import transforms, models

from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split


# clear_output()
print('Setup complete. Using torch %s %s' % (torch.__version__, torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CPU'))

Setup complete. Using torch 1.13.0+cu116 CPU


## Загрузка данных

Данные  - фотографии документов, которые выравнены верно 

In [3]:
from google.colab import drive

In [4]:
drive.mount('/content/gdrive/')

Mounted at /content/gdrive/


In [5]:
img_zip = '/content/gdrive/MyDrive/FlipNet/img.zip'
zipfile_img = zipfile.ZipFile(img_zip, 'r')
zipfile_img.extractall()

## Генерация классов

создаем копиb каждой фотографии переварачивая ее на 90 градусов

In [6]:
! mv img flip0

In [7]:
!mkdir flip90 flip180 flip270

In [8]:
def rotate_img(list_imgs, rot):
    for img in list_imgs:
        if len(img) > 10:
            im = Image.open(os.path.join('flip0', img))
            im_rotate = im.rotate(rot, expand=True)
            im_rotate.save(os.path.join(f"flip{rot}", img))
            im.close()
 
images = [x for x in os.listdir('flip0')]
rotate_img(images, 90)
rotate_img(images, 180)
rotate_img(images, 270)

## Разделим данные на Train и Test

In [9]:
all_images = [[os.path.join('flip0', x), os.path.join('flip90', x), os.path.join('flip180', x), os.path.join('flip270', x)] for x in os.listdir('flip0')]
all_images = np.concatenate(np.array(all_images))
train_images, test_images = train_test_split(all_images, test_size=0.3, random_state=42)

In [10]:
!mkdir images images/train images/train/flip0 images/train/flip90 images/train/flip180 images/train/flip270
!mkdir images/test images/test/flip0 images/test/flip90 images/test/flip180 images/test/flip270

In [11]:
#Utility function to move images 
def move_files_to_folder(list_of_files, destination_folder):
    for f in list_of_files:
        try:
            if f.find('.DS_Store') != -1:
                continue
            prefix = f.split('/', 1)[0]
            shutil.move(f, os.path.join(destination_folder, prefix))
        except:
            print(f)
            assert False

# перетащим файлы по нужным папкам
move_files_to_folder(train_images, 'images/train')
move_files_to_folder(test_images, 'images/test/')

In [12]:
! rm -rf flip0 flip90 flip180 flip270  img

In [13]:
len(os.listdir('images/train/flip0')), len(os.listdir('images/train/flip90')), len(os.listdir('images/train/flip180')), len(os.listdir('images/train/flip270'))

(282, 281, 277, 277)

## Подготовка данных 

разобьем данные на бачи

In [14]:
def augmentation(train_dir = 'train'):
    mas = [transforms.RandomVerticalFlip(p=1), transforms.RandomHorizontalFlip(p=1),
    transforms.ColorJitter(brightness=.5, hue=.3), transforms.RandomRotation(degrees=(0, 180)),
    transforms.RandomAffine(degrees=(30, 70)), transforms.RandomEqualize(p=1)]
    ten = transforms.ToTensor()
    com_mas = []
    for i in range(len(mas)):
        com_mas.append(transforms.Compose([mas[i], ten]))
        for j in range(i+1, len(mas)):
            com_mas.append(transforms.Compose([mas[i], mas[j], ten]))
            for l in range(j+1, len(mas)):
                if l > j+2:
                    break
                com_mas.append(transforms.Compose([mas[i], mas[j], mas[l], ten]))
    dataset_mas = []
    for com in com_mas:
        dataset_mas.append(torchvision.datasets.ImageFolder(train_dir, com))
        
    return torch.utils.data.ConcatDataset(dataset_mas)


In [None]:
# train_transforms = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ColorJitter(brightness=.5, hue=.3),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])

# val_transforms = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])
# train_dir = 'images/train'
# val_dir = 'images/test'

# train_dataset = torchvision.datasets.ImageFolder(train_dir, train_transforms)
# val_dataset = torchvision.datasets.ImageFolder(val_dir, val_transforms)

# batch_size = 8
# train_dataloader = torch.utils.data.DataLoader(
#     train_dataset, batch_size=batch_size, shuffle=True, num_workers=batch_size)
# val_dataloader = torch.utils.data.DataLoader(
#     val_dataset, batch_size=batch_size, shuffle=False, num_workers=batch_size)



In [None]:
# train_transforms = torch.nn.Sequential(
#     transforms.Resize((224, 224)),
#     transforms.ColorJitter(brightness=.5, hue=.3),
#     transforms.FiveCrop(150),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# )

# val_transforms = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])
# train_dir = 'images/train'
# val_dir = 'images/test'

# train_dataset = torchvision.datasets.ImageFolder(train_dir, torch.jit.script(train_transforms))
# val_dataset = torchvision.datasets.ImageFolder(val_dir, val_transforms)

# batch_size = 8
# train_dataloader = torch.utils.data.DataLoader(
#     train_dataset, batch_size=batch_size, shuffle=True, num_workers=batch_size)
# val_dataloader = torch.utils.data.DataLoader(
#     val_dataset, batch_size=batch_size, shuffle=False, num_workers=batch_size)

TypeError: ignored

 #### Посмотрим как теперь выглядят наши фотографии

In [None]:
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
class_names = ['flop0', 'flop90', 'flop180', 'flop270']
def show_input(input_tensor, title=''):
    image = input_tensor.permute(1, 2, 0).numpy()
    image = std * image + mean
    plt.imshow(image.clip(0, 1))
    plt.title(title)
    plt.show()
    plt.pause(0.001)

X_batch, y_batch = next(iter(train_dataloader))

for x_item, y_item in zip(X_batch, y_batch):
    show_input(x_item, title=class_names[y_item])

RuntimeError: ignored

## Объявление модели

In [None]:
model = models.resnet50(pretrained=True)

"""отключить рассчет градиента для всех слоев сети
сеть хорошо предобучена и мы не хотим, что бы веса менялись во время обучения
"""
for param in model.parameters():
    param.requires_grad = False

"""меняем последний слой, вместо 1000 классов, как в оригинале, у нас будет 4
"""
model.fc = torch.nn.Linear(model.fc.in_features, 4)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
loss = torch.nn.CrossEntropyLoss() # функция активации 
optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-3) # метод оптимизации
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) # планировщик

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

## Тренировка модели

In [None]:
def train_model(model, loss, optimizer, scheduler, num_epochs):
    for epochs in range(num_epochs):
        print(f'Epoch {epochs + 1} / {num_epochs}')
        for phase in ['train', 'val']:
            if phase == 'train':
                dataloader = train_dataloader
                scheduler.step()
                model.train()
            else:
                dataloader = val_dataloader
                model.eval()
        true_ans, total, all_loss = 0, 0, 0
        for inputs, label in tqdm(dataloader):
            inputs, label = inputs.to(device), label.to(device)
            optimizer.zero_grad() # обнуляем градиент, что бы он не накапливался
            with torch.set_grad_enabled(phase == 'train'):
                preds = model(inputs)
                loss_value = loss(preds, label)
                preds_class = preds.argmax(dim=1)
                if phase == 'train':
                    loss_value.backward()
                    optimizer.step()
            true_ans += (preds_class == label).sum().item()
            total += label.size(0)
            all_loss += loss_value.item()

        print(f"{phase} accuracy of the network {100 * true_ans / total}, Loss {all_loss}")
    return model

In [None]:
flipmodel = train_model(model, loss, optimizer, scheduler, num_epochs=50);

Epoch 1 / 50


100%|██████████| 60/60 [00:08<00:00,  7.37it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 2 / 50


100%|██████████| 60/60 [00:08<00:00,  7.31it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 3 / 50


100%|██████████| 60/60 [00:08<00:00,  7.41it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 4 / 50


100%|██████████| 60/60 [00:08<00:00,  7.46it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 5 / 50


100%|██████████| 60/60 [00:08<00:00,  7.46it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 6 / 50


100%|██████████| 60/60 [00:08<00:00,  7.44it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 7 / 50


100%|██████████| 60/60 [00:08<00:00,  7.44it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 8 / 50


100%|██████████| 60/60 [00:07<00:00,  7.51it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 9 / 50


100%|██████████| 60/60 [00:08<00:00,  7.30it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 10 / 50


100%|██████████| 60/60 [00:08<00:00,  7.20it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 11 / 50


100%|██████████| 60/60 [00:08<00:00,  7.36it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 12 / 50


100%|██████████| 60/60 [00:08<00:00,  7.47it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 13 / 50


100%|██████████| 60/60 [00:08<00:00,  7.46it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 14 / 50


100%|██████████| 60/60 [00:08<00:00,  7.45it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 15 / 50


100%|██████████| 60/60 [00:08<00:00,  7.42it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 16 / 50


100%|██████████| 60/60 [00:08<00:00,  7.43it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 17 / 50


100%|██████████| 60/60 [00:08<00:00,  7.29it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 18 / 50


100%|██████████| 60/60 [00:08<00:00,  7.41it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 19 / 50


100%|██████████| 60/60 [00:08<00:00,  7.34it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 20 / 50


100%|██████████| 60/60 [00:08<00:00,  7.31it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 21 / 50


100%|██████████| 60/60 [00:08<00:00,  7.34it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 22 / 50


100%|██████████| 60/60 [00:08<00:00,  7.40it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 23 / 50


100%|██████████| 60/60 [00:08<00:00,  7.40it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 24 / 50


100%|██████████| 60/60 [00:08<00:00,  7.34it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 25 / 50


100%|██████████| 60/60 [00:08<00:00,  7.30it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 26 / 50


100%|██████████| 60/60 [00:08<00:00,  7.40it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 27 / 50


100%|██████████| 60/60 [00:08<00:00,  7.39it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 28 / 50


100%|██████████| 60/60 [00:08<00:00,  7.38it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 29 / 50


100%|██████████| 60/60 [00:08<00:00,  7.34it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 30 / 50


100%|██████████| 60/60 [00:08<00:00,  7.41it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 31 / 50


100%|██████████| 60/60 [00:08<00:00,  7.27it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 32 / 50


100%|██████████| 60/60 [00:08<00:00,  7.26it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 33 / 50


100%|██████████| 60/60 [00:08<00:00,  7.37it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 34 / 50


100%|██████████| 60/60 [00:08<00:00,  7.44it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 35 / 50


100%|██████████| 60/60 [00:08<00:00,  7.46it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 36 / 50


100%|██████████| 60/60 [00:08<00:00,  7.34it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 37 / 50


100%|██████████| 60/60 [00:08<00:00,  7.47it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 38 / 50


100%|██████████| 60/60 [00:08<00:00,  7.40it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 39 / 50


100%|██████████| 60/60 [00:08<00:00,  7.39it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 40 / 50


100%|██████████| 60/60 [00:08<00:00,  7.30it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 41 / 50


100%|██████████| 60/60 [00:08<00:00,  7.35it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 42 / 50


100%|██████████| 60/60 [00:08<00:00,  7.48it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 43 / 50


100%|██████████| 60/60 [00:08<00:00,  7.46it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 44 / 50


100%|██████████| 60/60 [00:08<00:00,  7.49it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 45 / 50


100%|██████████| 60/60 [00:08<00:00,  7.45it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 46 / 50


100%|██████████| 60/60 [00:08<00:00,  7.46it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 47 / 50


100%|██████████| 60/60 [00:08<00:00,  7.30it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 48 / 50


100%|██████████| 60/60 [00:08<00:00,  7.32it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 49 / 50


100%|██████████| 60/60 [00:08<00:00,  7.39it/s]


val accuracy of the network 22.175732217573223, Loss 84.87645947933197
Epoch 50 / 50


100%|██████████| 60/60 [00:08<00:00,  7.33it/s]

val accuracy of the network 22.175732217573223, Loss 84.87645947933197





## Сохранение модели

In [None]:
with open('model.pickle', 'wb') as f:
    pickle.dump(flipmodel, f)

with open('model.pickle', 'rb') as f:
    flipmodel = pickle.load(f)

In [None]:
%cp model.pickle /content/gdrive/My\ Drive/FlipNet/

## Тестирование 

In [None]:
def predict(model):
    model.eval() ## фиксируем модельку 
    test_predictions, true_predictions = [], []
    for inputs, labels in tqdm(val_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.set_grad_enabled(False):
            preds = model(inputs)
        test_predictions.append(
            np.argmax(torch.nn.functional.softmax(preds, dim=1).data.cpu().numpy(), axis=1))
        true_predictions.append(labels.data.cpu().numpy())
    return np.concatenate(true_predictions), np.concatenate(test_predictions)

In [None]:
true_predict, my_predict = predict(flipmodel)

100%|██████████| 60/60 [00:08<00:00,  7.32it/s]


In [None]:
print(classification_report(true_predict, my_predict))

              precision    recall  f1-score   support

           0       1.00      0.03      0.05       117
           1       0.14      0.01      0.02       121
           2       0.21      0.22      0.22       122
           3       0.22      0.64      0.33       118

    accuracy                           0.22       478
   macro avg       0.39      0.22      0.15       478
weighted avg       0.39      0.22      0.15       478



In [None]:
true_predict

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,

In [None]:
my_predict

array([3, 2, 3, 3, 3, 3, 2, 3, 2, 2, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 2,
       2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 0, 3, 3, 2, 0, 2, 3, 3, 0, 3, 3, 3,
       3, 3, 3, 3, 2, 2, 3, 3, 3, 2, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 2, 3,
       2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 3, 2, 3, 3, 3, 2,
       3, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 2, 3, 3, 3, 3, 3, 2,
       2, 3, 3, 2, 3, 3, 3, 2, 3, 3, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3, 3, 3, 3, 2, 2, 3, 3, 3,
       2, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3,
       3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 2, 3, 3, 3, 3, 3, 2, 3, 3, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       2, 3, 2, 3, 3, 3, 3, 3, 3, 2, 2, 3, 3, 3, 3, 3, 3, 2, 3, 3, 2, 3,
       2, 2, 2, 2, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 1, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3,
       1, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3,