In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy


In [2]:
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.25, 0.25, 0.25])

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(299),
        # transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
    'val': transforms.Compose([
        transforms.Resize(299),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
}

In [3]:
import os
import random
import torch
from torchvision import datasets, transforms

data_dir = "D:/P2023/DATA/glomer_cg"

# Create a single merged dataset
merged_dataset = datasets.ImageFolder(data_dir, data_transforms['train'])

# Shuffle the merged dataset randomly
random.seed(42)  # You can choose any random seed for reproducibility
indices = list(range(len(merged_dataset)))
random.shuffle(indices)

# Define the split ratio (e.g., 80% for training and 20% for validation)
split_ratio = 0.8  # You can adjust this ratio as needed

# Calculate the split indices
split_index = int(len(indices) * split_ratio)
train_indices = indices[:split_index]
val_indices = indices[split_index:]

# Create data loaders for training and validation using the split indices
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(merged_dataset, batch_size=4, sampler=train_sampler, num_workers=0)
val_loader = torch.utils.data.DataLoader(merged_dataset, batch_size=4, sampler=val_sampler, num_workers=0)

# Example usage:
print(f"Number of training samples: {len(train_indices)}")
print(f"Number of validation samples: {len(val_indices)}")


Number of training samples: 10171
Number of validation samples: 2543


In [4]:

# num_datapoint = len(raw_datasets)

image_datasets = {
    "train": train_loader.dataset,
    "val": val_loader.dataset
}
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=0)
    for x in ['train', 'val']
}

In [5]:
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(class_names)
print(device)

['IGA', 'MGN']
cuda:0


In [6]:
def imshow(inp, title):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    plt.title(title)
    plt.show()


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

# imshow(out, title=[class_names[x] for x in classes])

In [7]:
print(classes)

tensor([0, 0, 0, 1])


In [8]:
# complete
from datetime import datetime
log_path ="D:/P2023/LOG/inception_v3.txt"
# print and fprint at the same time
def pprint(output = '\n' , filename = log_path, show_time = False):
    print(output)
    with open(filename, 'a') as f:
        if show_time:
            f.write(datetime.now().strftime("[%Y-%m-%d %H:%M:%S] "))

        f.write(str(output))
        f.write('\n')
pprint("test", show_time=True)

test


In [44]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    best_recallmn = 0.0 # ***
    best_recallig = 0.0 # ***
    
    best_precmn = 0.0 # ***
    best_precig = 0.0 # ***
    
    best_f1mn = 0.0 # ***
    best_f1ig = 0.0 # ***

    for epoch in range(num_epochs):
        pprint('Epoch {}/{}'.format(epoch, num_epochs - 1), show_time=True)
        pprint('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            
            tp_positive = 0
            fp_positive = 0
            tn_negative = 0
            fn_negative = 0
            positive_other = 0
            negative_other = 0
            
            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)


                    # backward + optimize only if in training phase
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
                # MN -> positive
                # iga -> negative
                tp_positive += torch.sum((preds == 1) & (labels.data == 1))
                fp_positive += torch.sum((preds == 1) & (labels.data == 0))
                tn_negative += torch.sum((preds == 0) & (labels.data == 0))
                fn_negative += torch.sum((preds == 0) & (labels.data == 1))
                positive_other += torch.sum((preds == 2) & (labels.data == 1))
                negative_other += torch.sum((preds == 2) & (labels.data == 0))
                
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            epoch_recallmn = tp_positive.double() / (tp_positive + fn_negative) # ***
            epoch_recallig = tn_negative.double() / (tn_negative + fp_positive) # ***
            
            epoch_precmn = tp_positive.double() / (tp_positive + fp_positive) # ***
            epoch_precig = tn_negative.double() / (tn_negative + fn_negative) # ***
            
            epoch_f1mn = (2 * epoch_recallmn * epoch_precmn) / (epoch_recallmn + epoch_precmn) # ***
            epoch_f1ig = (2 * epoch_recallig * epoch_precig) / (epoch_recallig + epoch_precig) # ***
            
            # print('{} Loss: {:.4f} Acc: {:.4f} Recall_MGN: {:.4f} Recall_IGAN: {:.4f} Precision_MGN: {:.4f} Precision_IGAN: {:.4f} F1_MGN: {:.4f} F1_IGAN: {:.4f}'.format(
            #     phase, epoch_loss, epoch_acc, epoch_recallmn, epoch_recallig, epoch_precmn, epoch_precig, epoch_f1mn, epoch_f1ig)) # ***
            pprint('{} Loss: {:.4f} Accuracy: {:.4f} \n     Recall Precision F1_score OTHER\n MGN: {:.4f} {:.4f} {:.4f} {}\n IGAN:{:.4f} {:.4f} {:.4f} {}\n'.format(
                    phase, epoch_loss, epoch_acc, epoch_recallmn, epoch_precmn, epoch_f1mn, positive_other, epoch_recallig, epoch_precig, epoch_f1ig, negative_other)) # ***

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                
            if phase == 'val' and epoch_recallmn > best_recallmn: # ***
                best_recallmn = epoch_recallmn
            if phase == 'val' and epoch_recallig > best_recallig: # ***
                best_recallig = epoch_recallig  
                
            if phase == 'val' and epoch_precmn > best_precmn: # ***
                best_precmn = epoch_precmn
            if phase == 'val' and epoch_precig > best_precig: # ***
                best_precig = epoch_precig    
                
            if phase == 'val' and epoch_f1mn > best_f1mn: # ***
                best_f1mn = epoch_f1mn
            if phase == 'val' and epoch_f1ig > best_f1ig: # ***
                best_f1ig = epoch_f1ig   
                
        print()

    time_elapsed = time.time() - since
    pprint('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    pprint('Best val Acc: {:.4f} \n     Recall Precision F1_score\n MGN: {:.4f} {:.4f} {:.4f}\n IGAN:{:.4f} {:.4f} {:.4f}\n'.format(
                best_acc, best_recallmn, best_precmn, best_f1mn, best_recallig, best_precig, best_f1ig)) # ***

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [45]:
def train_mod(true_None, num_epoch):
    model = models.inception_v3(weights=true_None)
    model.aux_logits=False 
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)
    # model_ft.aux_logits=False
    
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    model = train_model(model, criterion, optimizer, step_lr_scheduler, num_epochs=num_epoch)    
    return model

In [None]:
weight_name = "D:/P2023/WEIGHT/inception_50t.pt"

pprint(f"Model at {weight_name}")
model = train_mod(True, 50)
torch.save(model.state_dict(), weight_name)

Model at D:/P2023/WEIGHT/inception_50t.pt
Epoch 0/49
----------
train Loss: 0.5669 Accuracy: 0.6976 
     Recall Precision F1_score OTHER
 MGN: 0.5112 0.6454 0.5705 0
 IGAN:0.8182 0.7211 0.7666 0

val Loss: 0.4407 Accuracy: 0.8065 
     Recall Precision F1_score OTHER
 MGN: 0.6805 0.7974 0.7343 0
 IGAN:0.8881 0.8111 0.8478 0


Epoch 1/49
----------
train Loss: 0.5085 Accuracy: 0.7525 
     Recall Precision F1_score OTHER
 MGN: 0.6323 0.7069 0.6675 0
 IGAN:0.8303 0.7772 0.8029 0

val Loss: 0.3835 Accuracy: 0.8400 
     Recall Precision F1_score OTHER
 MGN: 0.7936 0.7981 0.7959 0
 IGAN:0.8700 0.8669 0.8685 0


Epoch 2/49
----------
train Loss: 0.4738 Accuracy: 0.7773 
     Recall Precision F1_score OTHER
 MGN: 0.6809 0.7333 0.7062 0
 IGAN:0.8397 0.8026 0.8207 0

val Loss: 0.3457 Accuracy: 0.8593 
     Recall Precision F1_score OTHER
 MGN: 0.8171 0.8235 0.8203 0
 IGAN:0.8866 0.8822 0.8844 0


Epoch 3/49
----------
train Loss: 0.4373 Accuracy: 0.8021 
     Recall Precision F1_score OTHER
 

In [None]:
weight_name = "D:/P2023/WEIGHT/inception_50n.pt"

pprint(f"Model at {weight_name}")
model = train_mod(None, 50)
torch.save(model.state_dict(), weight_name)

In [None]:
# model = train_mod(True, 50)
# torch.save(model.state_dict(), "D:/P2023/WEIGHT/res18_50t.pt")