In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import numpy as np
import os
from PIL import Image
import time
from tqdm import tqdm
from collections import OrderedDict


In [138]:

# Define dataset class
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, root, dataframe, transform=None):
        self.root = root
        self.df = dataframe
        self.transform = transform

    def __getitem__(self, index):
        path = os.path.join(self.root, self.df.iloc[index]['path'])
        img = Image.open(path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = self.df.iloc[index]['config']
        return img, label

    def __len__(self):
        return len(self.df)


In [139]:
# Define train_transform to perform data augmentation
train_transform = transforms.Compose([
    transforms.Resize(256),          # Resize the image to 256x256 pixels
    transforms.RandomCrop(224),     # Randomly crop the image to 224x224 pixels
    transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
    transforms.ToTensor(),          # Convert the image to a tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the image
])

# Define test_transform without data augmentation
test_transform = transforms.Compose([
    transforms.Resize(256),          # Resize the image to 256x256 pixels
    transforms.CenterCrop(224),     # Crop the image from the center to 224x224 pixels
    transforms.ToTensor(),          # Convert the image to a tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the image
])

In [140]:
# Load dataset

root = 'E:/code/THESIS/printer_source'
df = pd.read_excel('E:/code/THESIS/printer_source/file2.xlsx', engine='openpyxl')
checkpoint_dir = 'E:/code/THESIS/printer_source/checkpoint'
num_epochs = 100
is_best= True


encoder = LabelEncoder()
df['config'] = encoder.fit_transform(df['config']) # encode labels
train_df, test_df = train_test_split(df, test_size=0.2, stratify=df['config']) # split into train and test sets
train_dataset = ImageDataset(root, train_df, transform=train_transform)
test_dataset = ImageDataset(root, test_df, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)


In [141]:
#mix up augmentation
def mixup_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [142]:
# Define model
model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(encoder.classes_))
device = torch.device("cuda")
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Define scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)




In [143]:
def save_checkpoint(state, is_best, checkpoint_dir):
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth').replace('\\','/')
    best_path = os.path.join(checkpoint_dir, 'best.pth').replace('\\','/')

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    torch.save(state, checkpoint_path)
    if is_best:
        torch.save(state, best_path)


def load_checkpoint(model, optimizer, checkpoint_dir):
    checkpoint_path = os.path.join(checkpoint_dir, 'resnet50.pth').replace('\\','/')

    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        return checkpoint['epoch'], checkpoint['best_acc']
    else:
        os.makedirs(checkpoint_dir, exist_ok=True)
        return 0, 0.0


def log_accuracy(epoch, train_acc, test_acc, checkpoint_dir):
    log_file = os.path.join(checkpoint_dir, 'log.txt')
    with open(log_file, 'a') as f:
        f.write(f"Epoch {epoch}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\n")


In [144]:
import torch
import os

def auto_load_resume(model, optimizer, scheduler, checkpoint_dir, device):
    current_model_path = os.path.join(checkpoint_dir, 'current_model.pth')
    best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')
    if not os.path.exists(current_model_path):
        return 0, 0.0
    else:
        print('Loading pretrained model!')
        checkpoint = torch.load(current_model_path, map_location=device)
        best_model = torch.load(best_model_path, map_location=device)
        new_state_dict = OrderedDict()
        for k, v in checkpoint['model_state_dict'].items():
            name = k[:]        #  = k[7:] if you used DataParallel
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
        model = model.to(device)
        epoch = checkpoint['epoch']+1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print('Resume from %s' % 'epoch ' + str(epoch+1))
        return epoch, best_model['test_acc']


In [161]:
def train(model, device, train_loader, test_loader, optimizer, scheduler, criterion, mixup_criterion, num_epochs, checkpoint_dir):
    print(f"{device}")
    # Train loop
    if os.path.exists(checkpoint_dir):
        epoch,best_acc=auto_load_resume(model, optimizer, scheduler, checkpoint_dir, device)
        assert epoch < num_epochs
    for epoch in range(epoch, num_epochs):
        train_loss = 0.0
        train_correct = 0
        best_acc=0.0
        # epoch,best_acc=load_checkpoint(model, optimizer, checkpoint_dir)


        model.train()
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)
            mixed_images, y_a, y_b, lam = mixup_data(images, labels)
            optimizer.zero_grad()
            outputs = model(mixed_images)
            loss = mixup_criterion(criterion, outputs, y_a, y_b, lam)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * mixed_images.size(0)
            _, preds = torch.max(outputs, 1)
            train_correct += torch.sum(preds == labels.data)
        
        train_loss = train_loss / len(train_loader.dataset)
        train_acc = train_correct.double() / len(train_loader.dataset)
        
        # Testing loop
        test_loss = 0.0
        test_correct = 0
        model.eval()
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                test_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)
                test_correct += torch.sum(preds == labels.data)
        test_loss = test_loss / len(test_loader.dataset)
        test_acc = test_correct.double() / len(test_loader.dataset)
        print(f'')
        
        # Update scheduler
        scheduler.step()

        # Print results
        print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'
            .format(epoch+1, num_epochs, train_loss, train_acc, test_loss, test_acc))

        # save checkpoint
        print('Saving checkpoint')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            # 'learning_rate': lr,
            # 'val_acc': val_accuracy,
            'test_acc': test_acc
        }, os.path.join(checkpoint_dir, 'current_model' + '.pth'))

        if test_acc > best_acc:
            print('Saving best model')
            
            best_acc = test_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                # 'learning_rate': lr,
                # 'val_acc': val_accuracy,
                'test_acc': test_acc
            }, os.path.join(checkpoint_dir, 'best_model' + '.pth'))

        # Log accuracy
        log_accuracy(epoch + 1, train_acc, test_acc, checkpoint_dir)




In [162]:
train(model, device, train_loader, test_loader, optimizer, scheduler, criterion, mixup_criterion, num_epochs, checkpoint_dir)

cuda
Loading pretrained model!
Resume from epoch 4


Epoch 4/100: 100%|██████████| 113/113 [00:22<00:00,  4.92it/s]



Epoch [4/100], Train Loss: 0.5903, Train Acc: 0.5506, Test Loss: 0.0488, Test Acc: 0.9911
Saving checkpoint
Saving best model


Epoch 5/100: 100%|██████████| 113/113 [00:23<00:00,  4.91it/s]



Epoch [5/100], Train Loss: 0.6212, Train Acc: 0.5817, Test Loss: 0.0480, Test Acc: 0.9911
Saving checkpoint
Saving best model


Epoch 6/100: 100%|██████████| 113/113 [00:23<00:00,  4.90it/s]



Epoch [6/100], Train Loss: 0.5521, Train Acc: 0.5892, Test Loss: 0.0497, Test Acc: 0.9911
Saving checkpoint
Saving best model


Epoch 7/100: 100%|██████████| 113/113 [00:23<00:00,  4.89it/s]



Epoch [7/100], Train Loss: 0.6212, Train Acc: 0.5633, Test Loss: 0.0479, Test Acc: 0.9922
Saving checkpoint
Saving best model


Epoch 8/100:  45%|████▌     | 51/113 [00:10<00:13,  4.76it/s]


KeyboardInterrupt: 

In [10]:
!python train.py -m resnet50
!ren checkpoint\resnet50 resnet50_config1 

^C


Access is denied.


In [3]:
dataframe = pd.read_excel('E:/code/THESIS/printer_source/file2.xlsx')


In [4]:
num_classes = len(dataframe['printer'].unique())
print("Number of output classes:", num_classes)


Number of output classes: 6
