In [2]:
import os
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import pytorch_lightning as pl
from torchmetrics.functional import accuracy, recall, specificity
import pickle
import argparse

In [2]:
import random

def make_filepath_list():
    train_file_list = []
    valid_file_list = []
    num_sumples = len(os.listdir('../dataset/train/images/off'))
    for top_dir in os.listdir('../dataset/train/images/'):
        file_dir = os.path.join('../dataset/train/images/',top_dir)
        if file_dir == '../dataset/train/images/._.DS_Store' or file_dir == '../dataset/train/images/.DS_Store':
            continue
        file_list = os.listdir(file_dir)
        random.shuffle(file_list)
        file_list = file_list[:num_sumples]

        

        #８割を学習データ、２割を検証データ
        num_data = len(file_list)
        num_split = int(num_data * 0.8)

        train_file_list += [os.path.join('../dataset/train/images/',top_dir,file).replace('\\','/') for file in file_list[:num_split]]
        valid_file_list += [os.path.join('../dataset/train/images/',top_dir,file).replace('\\','/') for file in file_list[num_split:]]

    return train_file_list, valid_file_list

train_file_list, valid_file_list = make_filepath_list()

if '../dataset/train/images/on/._.DS_Store' in train_file_list:
    train_file_list.remove('../dataset/train/images/on/._.DS_Store')
if '../dataset/train/images/on/._.DS_Store' in valid_file_list:
    valid_file_list.remove('../dataset/train/images/on/._.DS_Store')
if '../dataset/train/images/off/._.DS_Store' in train_file_list:
    train_file_list.remove('../dataset/train/images/off/._.DS_Store')
if '../dataset/train/images/off/._.DS_Store' in valid_file_list:
    valid_file_list.remove('../dataset/train/images/off/._.DS_Store')
if '../dataset/train/images/on/.DS_Store' in train_file_list:
    train_file_list.remove('../dataset/train/images/on/.DS_Store')
if '../dataset/train/images/on/.DS_Store' in valid_file_list:
    valid_file_list.remove('../dataset/train/images/on/.DS_Store')
if '../dataset/train/images/off/.DS_Store' in train_file_list:
    train_file_list.remove('../dataset/train/images/off/.DS_Store')
if '../dataset/train/images/off/.DS_Store' in valid_file_list:
    valid_file_list.remove('../dataset/train/images/off/.DS_Store')

print('学習データ数 : ', len(train_file_list))
print(train_file_list[:3])
print('検証データ数 : ', len(valid_file_list))
print(valid_file_list[:3])


学習データ数 :  3956
['../dataset/train/images/on/20201217_160_on_0000000067.jpg', '../dataset/train/images/on/20201217_007_on_0000000307.jpg', '../dataset/train/images/on/20201217_010_on_0000000026.jpg']
検証データ数 :  990
['../dataset/train/images/on/20201217_122_on_0000000039.jpg', '../dataset/train/images/on/20201217_029_on_0000000062.jpg', '../dataset/train/images/on/20201217_162_on_0000000361.jpg']


In [3]:
class ImageTransform(object):
    def __init__(self,resize,mean,std):
        self.data_transform = {
            'train': transforms.Compose([ 
                #データオーグメンテーション
                transforms.RandomHorizontalFlip(),
                #画像をresizexresizeの大きさに統一する
                transforms.Resize((resize,resize)),
                #Tensor型に変換する
                transforms.ToTensor(),
                #色情報の標準化
                transforms.Normalize(mean,std)
            ]),
            'valid': transforms.Compose([
                transforms.Resize((resize,resize)),
                transforms.ToTensor(),
                transforms.Normalize(mean,std)
            ]),
            'test': transforms.Compose([
                transforms.Resize((resize,resize)),
                transforms.ToTensor(),
                transforms.Normalize(mean,std)
            ])
        }
    def __call__(self, img, phase='train'):
        return self.data_transform[phase](img)

resize = 300
mean = (0.5,0.5,0.5)
std = (0.5,0.5,0.5)
transform = ImageTransform(resize,mean,std)

In [4]:
class SurgeryDataset(data.Dataset):
    def __init__(self,file_list,classes,transform=None,phase='test'):
        self.phase = phase
        self.file_list = file_list
        self.transform = transform
        self.classes = classes
        self.phase = phase
    def __len__(self):
        #画像の枚数を返す
        return len(self.file_list)
        
    def __getitem__(self,index):
        #前処理した画像データのTensor形式のデータとラベルを取得

        #指定したindexの画像を読み込む
        img_path = self.file_list[index]
        img = Image.open(img_path)

        #画像の前処理を実施
        img_transformed = self.transform(img,self.phase)

        #画像ラベルをファイル名から抜き出す
        if self.phase == 'train' or self.phase=='valid':
            label = self.file_list[index].split('_')[-2]
        else:
            label = self.file_list[index].split('_')[-1][:-4]
        

        #ラベル名を数値に変換
        label = self.classes.index(label)

        return img_transformed, label

surgery_classes = ['on','off']

#Datasetの作成
train_dataset = SurgeryDataset(
    file_list=train_file_list,classes=surgery_classes,
    transform=ImageTransform(resize,mean,std),
    phase='train'
)
valid_dataset = SurgeryDataset(
    file_list=valid_file_list,classes=surgery_classes,
    transform=ImageTransform(resize,mean,std),
    phase='valid'
)


In [12]:
index = 0

#バッチサイズの指定
batch_size = 16

#DataLoaderを作成
train_dataloader = data.DataLoader(
    train_dataset, batch_size=batch_size,
    num_workers=0,shuffle=True
)

valid_dataloader = data.DataLoader(
    valid_dataset,batch_size=16,num_workers=0,shuffle=True
)

# batch_iterator = iter(train_dataloader)
# inputs, labels = next(batch_iterator)

In [20]:
import timm

In [21]:
modelname = 'resnetv2_50'
FEATS = 2

In [22]:
class Net(pl.LightningModule):
    #ネットワークで使用する層を記述
    def __init__(self,date=None):
        super().__init__()
        self.model = timm.create_model(modelname, pretrained=True, num_classes=2)

    #順伝搬処理を記述
    def forward(self,x):
        x = self.model(x)
        return x
    
    def training_step(self,batch,batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat,y)

        return {'loss':loss, 'y_hat':y_hat, 'y':y,
        'batch_loss':loss.item()*x.size(0)}

    #各エポック終了時の処理を記述
    def training_epoch_end(self, train_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in 
        train_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in 
        train_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in 
        train_step_outputs]) / y_hat.size(0)

        preds = torch.argmax(y_hat,dim=1)
        acc = accuracy(preds, y)
        self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_epoch=True)

        print('-------- Current Epoch {} --------'.format(self.current_epoch + 1))
        print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loss, acc))
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        
        return {'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    def validation_epoch_end(self, val_step_outputs):
        # x_hatを一つにまとめる
        y_hat = torch.cat([val['y_hat'] for val in val_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in val_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in val_step_outputs]) / y_hat.size(0)

        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy(preds, y)
        rec = recall(preds,y)
        spec = specificity(preds,y)

        self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True)

        print('valid Loss: {:.4f} valid Acc: {:.4f} valid Recall: {:.4f} valid Specificity: {:.4f}'.format(epoch_loss, acc, rec, spec))

    def test_step(self, batch, batch_idx):
        x, y = batch

        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        
        return {'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    def test_epoch_end(self, test_step_outputs):
        # x_hatを一つにまとめる
        y_hat = torch.cat([val['y_hat'] for val in test_step_outputs], dim=0)

        y = torch.cat([val['y'] for val in test_step_outputs], dim=0)
        # with open('../dataset/test/'+date+'/results/true.pickle', mode='wb') as f:
        #     pickle.dump(y, f)
        epoch_loss = sum([val['batch_loss'] for val in test_step_outputs]) / y_hat.size(0)

        preds = torch.argmax(y_hat, dim=1)
        with open('../dataset/test/'+date+'/results/preds_bin.pickle', mode='wb') as f:
            pickle.dump(preds, f)
        acc = accuracy(preds, y)

        self.log('test_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('test_acc', acc, prog_bar=True, on_epoch=True)

        print('test Loss: {:.4f} test Acc: {:.4f}'.format(epoch_loss, acc))

    # 最適化手法を記述する
    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.01)

        return optimizer



net = Net()

es = pl.callbacks.EarlyStopping(monitor='val_loss')

trainer = pl.Trainer(
    max_epochs=20,
    callbacks=[es],
    gpus = 0,
)


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1h-000cdf49.pth" to /Users/taichii/.cache/torch/hub/checkpoints/resnetv2_50_a1h-000cdf49.pth
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type     | Params
-----------------------------------
0 | model | ResNetV2 | 23.5 M
-----------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.018    Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Validation sanity check: 100%|██████████| 2/2 [00:14<00:00,  7.28s/it]valid Loss: 0.6947 valid Acc: 0.4688
Epoch 0:   0%|          | 0/310 [00:00<00:00, 2409.13it/s]            

  rank_zero_warn(


Epoch 0:   1%|          | 2/310 [00:45<1:18:30, 15.29s/it, loss=0.695, v_num=2]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


# train

In [None]:
trainer.fit(
    net,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

torch.save(net.state_dict(),'../models/model.pth')

# test

In [None]:
parser = argparse.ArgumentParser()

parser.add_argument('-date')

args = parser.parse_args()
date = args.date

#テストデータのファイルパス
test_file_list = []
file_dir = '../dataset/test/'+date+'/images/'
file_list = os.listdir(file_dir)
test_file_list += [os.path.join('../dataset/test/'+date+'/images/',file).replace('\\','/') for file in file_list]
test_file_list = sorted(test_file_list)
if '../dataset/test/'+date+'/images/._.DS_Store' in test_file_list:
    test_file_list.remove('../dataset/test/'+date+'/images/._.DS_Store')
if '../dataset/test/'+date+'/images/.DS_Store' in test_file_list:
    test_file_list.remove('../dataset/test/'+date+'/images/.DS_Store')

with open('../dataset/test/'+date+'/results/path_list.pickle', mode='wb') as f:
    pickle.dump(test_file_list, f)

print('テストデータ数 : ', len(test_file_list))
print(test_file_list[:3])

In [None]:
test_dataset = SurgeryDataset(
    file_list=test_file_list,classes=surgery_classes,
    transform=ImageTransform(resize,mean,std),
    phase='test'
)
index = 0

#バッチサイズの指定
batch_size = 32

#DataLoaderを作成

test_dataloader = data.DataLoader(
    test_dataset,batch_size=16,num_workers=16,shuffle=False
)

In [None]:
net.load_state_dict(torch.load('../weights/Learned_model.pt'))

trainer = pl.Trainer(
    max_epochs=20,
    gpus = 1,
)

trainer.test(model=net, dataloaders=test_dataloader )