In [1]:
import numpy as np
from PIL import Image
import os
import cv2
import torch 
import torch.nn as nn 
from torch.utils.data import DataLoader, Dataset, random_split, Subset
import torchvision
from torchvision.io import read_image
from torchvision import datasets, models, transforms
from tqdm import tqdm
import regex as re 
import pandas as pd
import os
import wandb

In [2]:
# # run this cell if running the notebook for the first time. 
def sort_fn(x):
    num1 = re.findall(r"image([\d]+)", x)
    return int(num1[0])

# Sort read in image filenames. 
image_dir = "./images/"
image_filenames = os.listdir(os.path.join(os.getcwd(), "images"))
image_filenames.sort(key= sort_fn)

# Fetch complete image dir. 
image_filenames = [os.path.join(image_dir, i) for i in image_filenames]
len(image_filenames)

labels = [1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1 ,1,1,1,0, 1, 0, 0, 0 ,1,0,1, 0, 0, 1, 0, 0, 1, 1,0, 1, 0,0,1, 1, 1, 1, 1, 1, 1,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,0,0,1,1,1,1,1,1,1,1,1,1,0,1,1,1,0,0,0,0,1,1,0,0,0,1,0,1,0,1,0,1,1,0,1,1,0,0,1,1,1,1,1,0,0,1,1,0,1,1,1,1,0,0,0,1,1,0,1]
pd.DataFrame({"image":image_filenames, "label":labels}).to_csv("train_data.csv", index=False)

In [2]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [3]:
class WindmillDataset(Dataset):
    def __init__(self, transform=None):
        self.data_dir = 'windmill'
        valid_targets = os.listdir(os.path.join(self.data_dir, 'val', 'with_mill'))
        valid_backgrounds = os.listdir(os.path.join(self.data_dir, 'val', 'without'))
        self.valid_file_paths = valid_targets + valid_backgrounds

        train_targets = os.listdir(os.path.join(self.data_dir, 'train', 'with_mill'))
        train_backgrounds = os.listdir(os.path.join(self.data_dir, 'train', 'without'))
        self.train_file_paths = train_targets + train_backgrounds

        valid_data = pd.concat((pd.DataFrame({'name': valid_targets, 'label': 1}),
                             pd.DataFrame({'name': valid_backgrounds, 'label': 0})))
        self.valid_labels = torch.from_numpy(valid_data.iloc[:, 1].to_numpy()).to(torch.long)

        train_data = pd.concat((pd.DataFrame({'name': train_targets, 'label': 1}),
                             pd.DataFrame({'name': train_backgrounds, 'label': 0})))
        self.train_labels = torch.from_numpy(train_data.iloc[:, 1].to_numpy()).to(torch.long)

    def __len__(self):
        return len(self.train_file_paths) + len(self.valid_file_paths)

    def __getitem__(self, idx):
    # idx < 770 is validation data, others are train
        if idx < 770:
            img_path = self.valid_file_paths[idx]
            label = self.valid_labels[idx]
            full_path = os.path.join(os.path.join(self.data_dir, 'val', ('with_mill' if label else 'without')), img_path)
        else:
            real_id = idx - 770
            img_path = self.train_file_paths[real_id]
            label = self.train_labels[real_id]
            full_path = os.path.join(os.path.join(self.data_dir, 'train', ('with_mill' if label else 'without')), img_path)

        image = read_image(full_path)
        return image, label


In [4]:
class ImageDataset(Dataset):
    # Set output = False if doesn't want the print. 
    def __init__(self, image_filenames, labels):
        super().__init__()
        self.image_filenames = image_filenames
        self.labels = labels
        # Read all images into numpy array. 
        images = []
        for i in self.image_filenames:
            images.append(cv2.imread(i))
        assert len(images) == len(self.image_filenames)
        assert len(images) == len(labels)
        # Convert to tensor and transpose. 
        self.images = torch.tensor(np.array(images)).transpose(1, -1)

        print("The shape of all images: ", self.images.shape)

    def __getitem__(self, i):
        final_image = self.images[i]
        return final_image, self.labels[i]

    def __len__(self):
        return len(self.image_filenames)

In [21]:
def load_checkpoint(model, save_name, save_dir = "model_ckpt"):
    save_path = os.path.join(save_dir, save_name)
    assert os.path.exists(save_path), "Path doesn't exist!"
    print("Loading checkpoint...")
    checkpoint = torch.load(save_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    best_loss = checkpoint['loss']
    print("Done!")

    return best_loss


def train(model, optimizer, data_augment, scheduler, train_dataloader, validation_dataloader, criterion, num_epochs = 5, load_from_save =False, save_name = "model0", save_dir = "model_ckpt"):
    best_acc = 0.0
    if load_from_save:
        print("Loading the model...")
        best_acc = load_checkpoint(model, optimizer, save_name + ".tar", save_dir)
        print("Done!")
    print(optimizer.param_groups[0]["lr"])
    save_path = os.path.join(save_dir, save_name + ".tar")
    best_state_dict = {}

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        num_correct = 0
        for _, (inputs, labels) in enumerate(tqdm(train_dataloader)):
            inputs = inputs.float().to(device)
            if data_augment:
                inputs = data_augment(inputs)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            pred = model(inputs)
            loss = criterion(pred, labels)
            num_correct += (pred.argmax(1) == labels).sum().item()
            loss.backward()
            optimizer.step()
            if scheduler:
                scheduler.step()
            train_loss += loss.item()
            del inputs, labels, pred
        
        epoch_loss = train_loss / len(train_dataloader.dataset)
        print(f"Epoch {epoch} training loss: ", epoch_loss)
        valid_loss, valid_acc = evaluate(model, validation_dataloader=validation_dataloader, criterion=criterion)
        wandb.log({'valid epoch loss': valid_loss,
                    'valid accuracy': valid_acc,
                    'train epoch loss': epoch_loss, 
                    'train epoch accuracy': num_correct / len(train_dataloader.dataset)})

        if valid_acc > best_acc:
            if not os.path.exists(save_dir):
                os.mkdir(save_dir)
            
            best_acc = valid_acc
            state_dict = {
                'model_state_dict': model.state_dict(),
                'optim_state_dict': optimizer.state_dict(),
                'loss': best_acc
            }
            best_state_dict = state_dict
            
            torch.save(state_dict, save_path)
            print(f"Saved checkpoint to {save_path}")

        print(f"Best valid accuracy: {best_acc}")
    model.load_state_dict(best_state_dict['model_state_dict'])

def evaluate(model, validation_dataloader, criterion):
    model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        correct = 0
        for _, (inputs, labels) in tqdm(enumerate(validation_dataloader), total=len(validation_dataloader)):
            inputs = inputs.float().to(device)
            labels = labels.to(device)
            pred = model(inputs)
            pred_labels = pred.argmax(1)
            correct += (pred_labels == labels).sum().item()
            loss = criterion(pred, labels)
            valid_loss += loss.item()
        acc = correct / len(validation_dataloader.dataset)
        valid_loss = valid_loss / len(validation_dataloader.dataset)
        print("Validation loss: ", valid_loss)
        print("Validation acc: ", acc)
        return valid_loss, acc

In [8]:
##### Windmill #####
batch_size = 8

ds = WindmillDataset()
valid_data, train_data = Subset(ds, torch.arange(0, 770)), Subset(ds, torch.arange(770, len(ds)))
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=True)

In [9]:
###### Finetune on windmill ######
lr = 1e-3
num_epochs = 6
model = models.efficientnet_b5(weights='DEFAULT')
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 2)

model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = None

data_augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
])

wandb.login()
wandb.init(project='182')
wandb.run.name = 'windmill_effnet'
model_save_name = wandb.run.name
train(model, optimizer=optimizer, data_augment=data_augment, scheduler=scheduler, train_dataloader=train_dataloader, validation_dataloader=validation_dataloader, criterion=criterion, num_epochs=num_epochs, save_name=model_save_name)
wandb.run.finish()

[34m[1mwandb[0m: Currently logged in as: [33mjonluj[0m. Use [1m`wandb login --relogin`[0m to force relogin


0.001


100%|██████████| 388/388 [01:00<00:00,  6.40it/s]


Epoch 0 training loss:  0.024777112509492005


100%|██████████| 97/97 [00:03<00:00, 25.15it/s]


Validation loss:  0.009148720487883701
Validation acc:  0.9701298701298702
Saved checkpoint to model_ckpt/windmill_effnet.tar
Best valid accuracy: 0.9701298701298702


100%|██████████| 388/388 [00:56<00:00,  6.92it/s]


Epoch 1 training loss:  0.01323608439744285


100%|██████████| 97/97 [00:03<00:00, 30.09it/s]


Validation loss:  0.007766154914059831
Validation acc:  0.977922077922078
Saved checkpoint to model_ckpt/windmill_effnet.tar
Best valid accuracy: 0.977922077922078


100%|██████████| 388/388 [00:55<00:00,  6.97it/s]


Epoch 2 training loss:  0.011496804984335246


100%|██████████| 97/97 [00:03<00:00, 30.78it/s]


Validation loss:  0.005977574167107897
Validation acc:  0.987012987012987
Saved checkpoint to model_ckpt/windmill_effnet.tar
Best valid accuracy: 0.987012987012987


100%|██████████| 388/388 [00:53<00:00,  7.19it/s]


Epoch 3 training loss:  0.011169983834054532


100%|██████████| 97/97 [00:03<00:00, 26.09it/s]


Validation loss:  0.014339059190643202
Validation acc:  0.9506493506493506
Best valid accuracy: 0.987012987012987


100%|██████████| 388/388 [00:55<00:00,  6.99it/s]


Epoch 4 training loss:  0.0073895332858064065


100%|██████████| 97/97 [00:04<00:00, 24.19it/s]


Validation loss:  0.004247639591990087
Validation acc:  0.9844155844155844
Best valid accuracy: 0.987012987012987


100%|██████████| 388/388 [00:55<00:00,  6.96it/s]


Epoch 5 training loss:  0.008209522694209402


100%|██████████| 97/97 [00:03<00:00, 25.72it/s]

Validation loss:  0.007476578584913574
Validation acc:  0.9805194805194806
Best valid accuracy: 0.987012987012987





VBox(children=(Label(value='0.007 MB of 0.007 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train epoch accuracy,▁▆█▇██
train epoch loss,█▃▃▃▁▁
valid accuracy,▅▆█▁█▇
valid epoch loss,▄▃▂█▁▃

0,1
train epoch accuracy,0.97774
train epoch loss,0.00821
valid accuracy,0.98052
valid epoch loss,0.00748


In [22]:
##### Main Task #####
np.random.seed(42)
torch.manual_seed(42)
df = pd.read_csv("train_data1.csv")
data = df.sample(frac=1)

# Build the dataset. 
batch_size = 8

ds = ImageDataset(data['image'].values.tolist(), data['label'].values.tolist())
train_dataset, validation_dataset, test_dataset = random_split(ds, [0.65, 0.2, 0.15], generator=torch.Generator().manual_seed(42))
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

The shape of all images:  torch.Size([406, 3, 500, 500])


In [12]:
##### Train Head #####
lr = 1e-3
num_epochs = 20

# Set all params besides the new head to be frozen
for param in model.parameters():
    param.requires_grad = False
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 2)

data_augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
])

# data_augment = None
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=len(train_dataloader), epochs=num_epochs)
scheduler = None

wandb.init(project='182')
wandb.run.name = 'effnet_head'
model_save_name = wandb.run.name
train(model, optimizer=optimizer, data_augment=data_augment, scheduler=scheduler, train_dataloader=train_dataloader, validation_dataloader=validation_dataloader, criterion=criterion, num_epochs=num_epochs, save_name=model_save_name)
wandb.run.finish()

VBox(children=(Label(value='0.007 MB of 0.014 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.464667…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669539369953176, max=1.0…

0.001


100%|██████████| 33/33 [00:05<00:00,  6.51it/s]


Epoch 0 training loss:  0.08587763210137685


100%|██████████| 11/11 [00:01<00:00, 10.88it/s]


Validation loss:  0.09271145157697724
Validation acc:  0.524390243902439
Saved checkpoint to model_ckpt/effnet_head.tar
Best valid accuracy: 0.524390243902439


100%|██████████| 33/33 [00:05<00:00,  6.45it/s]


Epoch 1 training loss:  0.08428374632741466


100%|██████████| 11/11 [00:00<00:00, 11.94it/s]


Validation loss:  0.09249472981545984
Validation acc:  0.573170731707317
Saved checkpoint to model_ckpt/effnet_head.tar
Best valid accuracy: 0.573170731707317


100%|██████████| 33/33 [00:05<00:00,  6.48it/s]


Epoch 2 training loss:  0.08666895465417342


100%|██████████| 11/11 [00:00<00:00, 11.68it/s]


Validation loss:  0.09477801075795801
Validation acc:  0.43902439024390244
Best valid accuracy: 0.573170731707317


100%|██████████| 33/33 [00:04<00:00,  7.05it/s]


Epoch 3 training loss:  0.08329305052757263


100%|██████████| 11/11 [00:00<00:00, 12.24it/s]


Validation loss:  0.09379390463596438
Validation acc:  0.4878048780487805
Best valid accuracy: 0.573170731707317


100%|██████████| 33/33 [00:04<00:00,  7.01it/s]


Epoch 4 training loss:  0.08604120271223964


100%|██████████| 11/11 [00:00<00:00, 12.73it/s]


Validation loss:  0.0910384458739583
Validation acc:  0.5975609756097561
Saved checkpoint to model_ckpt/effnet_head.tar
Best valid accuracy: 0.5975609756097561


100%|██████████| 33/33 [00:04<00:00,  7.20it/s]


Epoch 5 training loss:  0.0827622464434667


100%|██████████| 11/11 [00:00<00:00, 13.26it/s]


Validation loss:  0.09265567616718572
Validation acc:  0.5121951219512195
Best valid accuracy: 0.5975609756097561


100%|██████████| 33/33 [00:04<00:00,  7.30it/s]


Epoch 6 training loss:  0.08462863144549457


100%|██████████| 11/11 [00:00<00:00, 13.37it/s]


Validation loss:  0.09369843398652426
Validation acc:  0.47560975609756095
Best valid accuracy: 0.5975609756097561


100%|██████████| 33/33 [00:04<00:00,  7.29it/s]


Epoch 7 training loss:  0.0838178118521517


100%|██████████| 11/11 [00:00<00:00, 12.83it/s]


Validation loss:  0.09108640580642514
Validation acc:  0.5975609756097561
Best valid accuracy: 0.5975609756097561


100%|██████████| 33/33 [00:04<00:00,  7.37it/s]


Epoch 8 training loss:  0.08562946917884277


100%|██████████| 11/11 [00:00<00:00, 12.75it/s]


Validation loss:  0.09058718201590747
Validation acc:  0.6097560975609756
Saved checkpoint to model_ckpt/effnet_head.tar
Best valid accuracy: 0.6097560975609756


100%|██████████| 33/33 [00:04<00:00,  7.23it/s]


Epoch 9 training loss:  0.0841957533901388


100%|██████████| 11/11 [00:00<00:00, 13.08it/s]


Validation loss:  0.09009753858170859
Validation acc:  0.6341463414634146
Saved checkpoint to model_ckpt/effnet_head.tar
Best valid accuracy: 0.6341463414634146


100%|██████████| 33/33 [00:04<00:00,  7.43it/s]


Epoch 10 training loss:  0.08445841773892894


100%|██████████| 11/11 [00:00<00:00, 13.05it/s]


Validation loss:  0.09229739119366902
Validation acc:  0.573170731707317
Best valid accuracy: 0.6341463414634146


100%|██████████| 33/33 [00:04<00:00,  7.41it/s]


Epoch 11 training loss:  0.08444302754871773


100%|██████████| 11/11 [00:00<00:00, 12.84it/s]


Validation loss:  0.08930245914110323
Validation acc:  0.5853658536585366
Best valid accuracy: 0.6341463414634146


100%|██████████| 33/33 [00:04<00:00,  7.34it/s]


Epoch 12 training loss:  0.08407877385616302


100%|██████████| 11/11 [00:00<00:00, 12.89it/s]


Validation loss:  0.08823384380922085
Validation acc:  0.5853658536585366
Best valid accuracy: 0.6341463414634146


100%|██████████| 33/33 [00:04<00:00,  7.42it/s]


Epoch 13 training loss:  0.08541931189370877


100%|██████████| 11/11 [00:00<00:00, 13.06it/s]


Validation loss:  0.08906597771295686
Validation acc:  0.5853658536585366
Best valid accuracy: 0.6341463414634146


100%|██████████| 33/33 [00:04<00:00,  7.43it/s]


Epoch 14 training loss:  0.08639389784498648


100%|██████████| 11/11 [00:00<00:00, 13.08it/s]


Validation loss:  0.08830613552070246
Validation acc:  0.5853658536585366
Best valid accuracy: 0.6341463414634146


100%|██████████| 33/33 [00:04<00:00,  7.41it/s]


Epoch 15 training loss:  0.08392567977760777


100%|██████████| 11/11 [00:00<00:00, 13.03it/s]


Validation loss:  0.08813217863803957
Validation acc:  0.5853658536585366
Best valid accuracy: 0.6341463414634146


100%|██████████| 33/33 [00:04<00:00,  7.38it/s]


Epoch 16 training loss:  0.08213936921321985


100%|██████████| 11/11 [00:00<00:00, 12.76it/s]


Validation loss:  0.08797520471782219
Validation acc:  0.6707317073170732
Saved checkpoint to model_ckpt/effnet_head.tar
Best valid accuracy: 0.6707317073170732


100%|██████████| 33/33 [00:04<00:00,  7.35it/s]


Epoch 17 training loss:  0.08412098116946942


100%|██████████| 11/11 [00:00<00:00, 13.03it/s]


Validation loss:  0.08947213175820141
Validation acc:  0.5853658536585366
Best valid accuracy: 0.6707317073170732


100%|██████████| 33/33 [00:04<00:00,  7.43it/s]


Epoch 18 training loss:  0.08591790948853348


100%|██████████| 11/11 [00:00<00:00, 13.09it/s]


Validation loss:  0.09112096850464983
Validation acc:  0.6219512195121951
Best valid accuracy: 0.6707317073170732


100%|██████████| 33/33 [00:04<00:00,  7.35it/s]


Epoch 19 training loss:  0.0843190923333168


100%|██████████| 11/11 [00:00<00:00, 13.10it/s]


Validation loss:  0.0907849495003863
Validation acc:  0.6219512195121951
Best valid accuracy: 0.6707317073170732


VBox(children=(Label(value='0.007 MB of 0.021 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.316454…

0,1
train epoch accuracy,▃▅▄▆▄▅▂▅▄▇▅▅▆▅█▄▇▅▁▃
train epoch loss,▇▄█▃▇▂▅▄▆▄▅▅▄▆█▄▁▄▇▄
valid accuracy,▄▅▁▂▆▃▂▆▆▇▅▅▅▅▅▅█▅▇▇
valid epoch loss,▆▆█▇▄▆▇▄▄▃▅▂▁▂▁▁▁▃▄▄

0,1
train epoch accuracy,0.58333
train epoch loss,0.08432
valid accuracy,0.62195
valid epoch loss,0.09078


In [23]:
##### Finetune Main Task #####
lr = 1e-4
num_epochs = 35

for param in model.parameters():
    param.requires_grad = True

data_augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
])

# data_augment = None
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=len(train_dataloader), epochs=num_epochs)
scheduler = None

wandb.init(project='182')
wandb.run.name = 'effnet_finetune'
model_save_name = wandb.run.name
train(model, optimizer=optimizer, data_augment=data_augment, scheduler=scheduler, train_dataloader=train_dataloader, validation_dataloader=validation_dataloader, criterion=criterion, num_epochs=num_epochs, save_name=model_save_name)
wandb.run.finish()

0.0001


100%|██████████| 33/33 [00:10<00:00,  3.03it/s]


Epoch 0 training loss:  0.0815667679364031


100%|██████████| 11/11 [00:00<00:00, 13.02it/s]


Validation loss:  0.09726313483424304
Validation acc:  0.4024390243902439
Saved checkpoint to model_ckpt/effnet_noaug.tar
Best valid accuracy: 0.4024390243902439


100%|██████████| 33/33 [00:10<00:00,  3.02it/s]


Epoch 1 training loss:  0.058092874390157784


100%|██████████| 11/11 [00:00<00:00, 13.03it/s]


Validation loss:  0.08640126193441995
Validation acc:  0.6463414634146342
Saved checkpoint to model_ckpt/effnet_noaug.tar
Best valid accuracy: 0.6463414634146342


100%|██████████| 33/33 [00:10<00:00,  3.01it/s]


Epoch 2 training loss:  0.03618222201299487


100%|██████████| 11/11 [00:00<00:00, 12.99it/s]


Validation loss:  0.09853091000056849
Validation acc:  0.6341463414634146
Best valid accuracy: 0.6463414634146342


100%|██████████| 33/33 [00:10<00:00,  3.01it/s]


Epoch 3 training loss:  0.025319902776655825


100%|██████████| 11/11 [00:00<00:00, 12.95it/s]


Validation loss:  0.08423895152603708
Validation acc:  0.7195121951219512
Saved checkpoint to model_ckpt/effnet_noaug.tar
Best valid accuracy: 0.7195121951219512


100%|██████████| 33/33 [00:11<00:00,  3.00it/s]


Epoch 4 training loss:  0.01929018641539821


100%|██████████| 11/11 [00:00<00:00, 12.68it/s]


Validation loss:  0.07943899638769103
Validation acc:  0.6829268292682927
Best valid accuracy: 0.7195121951219512


100%|██████████| 33/33 [00:11<00:00,  2.99it/s]


Epoch 5 training loss:  0.01958513229549157


100%|██████████| 11/11 [00:00<00:00, 13.02it/s]


Validation loss:  0.10412375083783777
Validation acc:  0.6707317073170732
Best valid accuracy: 0.7195121951219512


100%|██████████| 33/33 [00:10<00:00,  3.00it/s]


Epoch 6 training loss:  0.016570445545243496


100%|██████████| 11/11 [00:00<00:00, 13.02it/s]


Validation loss:  0.12288490010470879
Validation acc:  0.6463414634146342
Best valid accuracy: 0.7195121951219512


100%|██████████| 33/33 [00:10<00:00,  3.01it/s]


Epoch 7 training loss:  0.02003734451224745


100%|██████████| 11/11 [00:00<00:00, 13.02it/s]


Validation loss:  0.11110295373492124
Validation acc:  0.7073170731707317
Best valid accuracy: 0.7195121951219512


100%|██████████| 33/33 [00:10<00:00,  3.00it/s]


Epoch 8 training loss:  0.013938433079712206


100%|██████████| 11/11 [00:00<00:00, 13.02it/s]


Validation loss:  0.097376715846178
Validation acc:  0.7073170731707317
Best valid accuracy: 0.7195121951219512


100%|██████████| 33/33 [00:11<00:00,  2.99it/s]


Epoch 9 training loss:  0.009185508447417029


100%|██████████| 11/11 [00:00<00:00, 12.56it/s]


Validation loss:  0.06737775355577469
Validation acc:  0.7682926829268293
Saved checkpoint to model_ckpt/effnet_noaug.tar
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:11<00:00,  2.99it/s]


Epoch 10 training loss:  0.010866641132231576


100%|██████████| 11/11 [00:00<00:00, 13.02it/s]


Validation loss:  0.0904769733003
Validation acc:  0.7317073170731707
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:10<00:00,  3.00it/s]


Epoch 11 training loss:  0.011540917515627701


100%|██████████| 11/11 [00:00<00:00, 13.03it/s]


Validation loss:  0.09516352724011351
Validation acc:  0.6707317073170732
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:10<00:00,  3.00it/s]


Epoch 12 training loss:  0.010568534674045319


100%|██████████| 11/11 [00:00<00:00, 13.00it/s]


Validation loss:  0.0992669808428462
Validation acc:  0.7195121951219512
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:11<00:00,  3.00it/s]


Epoch 13 training loss:  0.009826469771338232


100%|██████████| 11/11 [00:00<00:00, 13.01it/s]


Validation loss:  0.1135524535415376
Validation acc:  0.7317073170731707
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:11<00:00,  3.00it/s]


Epoch 14 training loss:  0.009338908631241682


100%|██████████| 11/11 [00:00<00:00, 12.98it/s]


Validation loss:  0.09450870016362609
Validation acc:  0.7073170731707317
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:10<00:00,  3.00it/s]


Epoch 15 training loss:  0.009068369004649647


100%|██████████| 11/11 [00:00<00:00, 13.01it/s]


Validation loss:  0.10028567673956476
Validation acc:  0.7317073170731707
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:11<00:00,  2.99it/s]


Epoch 16 training loss:  0.007763538771982756


100%|██████████| 11/11 [00:00<00:00, 12.74it/s]


Validation loss:  0.12893485632248042
Validation acc:  0.6951219512195121
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:11<00:00,  2.99it/s]


Epoch 17 training loss:  0.005588942304408799


100%|██████████| 11/11 [00:00<00:00, 13.00it/s]


Validation loss:  0.10726450220113848
Validation acc:  0.7560975609756098
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:11<00:00,  3.00it/s]


Epoch 18 training loss:  0.00544339223856558


100%|██████████| 11/11 [00:00<00:00, 13.00it/s]


Validation loss:  0.13545331567889307
Validation acc:  0.7073170731707317
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:11<00:00,  3.00it/s]


Epoch 19 training loss:  0.006909045330813211


100%|██████████| 11/11 [00:00<00:00, 13.09it/s]


Validation loss:  0.15436187804472157
Validation acc:  0.7439024390243902
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:10<00:00,  3.01it/s]


Epoch 20 training loss:  0.008136758639011532


100%|██████████| 11/11 [00:00<00:00, 12.99it/s]


Validation loss:  0.11372345031761541
Validation acc:  0.6951219512195121
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:10<00:00,  3.00it/s]


Epoch 21 training loss:  0.006296178176963815


100%|██████████| 11/11 [00:00<00:00, 13.03it/s]


Validation loss:  0.12649117446527247
Validation acc:  0.7682926829268293
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:11<00:00,  2.99it/s]


Epoch 22 training loss:  0.008991835580673069


100%|██████████| 11/11 [00:00<00:00, 12.71it/s]


Validation loss:  0.1523359352665976
Validation acc:  0.6951219512195121
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:11<00:00,  2.99it/s]


Epoch 23 training loss:  0.010300000535966265


100%|██████████| 11/11 [00:00<00:00, 13.01it/s]


Validation loss:  0.11836741801078726
Validation acc:  0.7439024390243902
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:11<00:00,  3.00it/s]


Epoch 24 training loss:  0.007758659672672917


100%|██████████| 11/11 [00:00<00:00, 12.99it/s]


Validation loss:  0.09007418582715639
Validation acc:  0.7560975609756098
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:11<00:00,  3.00it/s]


Epoch 25 training loss:  0.008435708272648326


100%|██████████| 11/11 [00:00<00:00, 13.02it/s]


Validation loss:  0.09907317624949827
Validation acc:  0.7439024390243902
Best valid accuracy: 0.7682926829268293


100%|██████████| 33/33 [00:11<00:00,  3.00it/s]


Epoch 26 training loss:  0.005178943783958499


100%|██████████| 11/11 [00:00<00:00, 12.97it/s]


Validation loss:  0.10199363284358164
Validation acc:  0.7804878048780488
Saved checkpoint to model_ckpt/effnet_noaug.tar
Best valid accuracy: 0.7804878048780488


100%|██████████| 33/33 [00:11<00:00,  2.98it/s]


Epoch 27 training loss:  0.0038361384657477947


100%|██████████| 11/11 [00:00<00:00, 12.72it/s]


Validation loss:  0.11932191361741322
Validation acc:  0.7560975609756098
Best valid accuracy: 0.7804878048780488


100%|██████████| 33/33 [00:11<00:00,  2.99it/s]


Epoch 28 training loss:  0.005100196288725728


100%|██████████| 11/11 [00:00<00:00, 13.00it/s]


Validation loss:  0.13467767748345688
Validation acc:  0.7682926829268293
Best valid accuracy: 0.7804878048780488


100%|██████████| 33/33 [00:11<00:00,  3.00it/s]


Epoch 29 training loss:  0.0070383351924123636


100%|██████████| 11/11 [00:00<00:00, 13.02it/s]


Validation loss:  0.13882970155739202
Validation acc:  0.6341463414634146
Best valid accuracy: 0.7804878048780488


100%|██████████| 33/33 [00:10<00:00,  3.00it/s]


Epoch 30 training loss:  0.004776536777860725


100%|██████████| 11/11 [00:00<00:00, 13.05it/s]


Validation loss:  0.1395548509388435
Validation acc:  0.7195121951219512
Best valid accuracy: 0.7804878048780488


100%|██████████| 33/33 [00:11<00:00,  2.99it/s]


Epoch 31 training loss:  0.0045548945230565905


100%|██████████| 11/11 [00:00<00:00, 12.75it/s]


Validation loss:  0.134148759994565
Validation acc:  0.6951219512195121
Best valid accuracy: 0.7804878048780488


100%|██████████| 33/33 [00:11<00:00,  2.99it/s]


Epoch 32 training loss:  0.005932493422623145


100%|██████████| 11/11 [00:00<00:00, 13.02it/s]


Validation loss:  0.12054907380625969
Validation acc:  0.7073170731707317
Best valid accuracy: 0.7804878048780488


100%|██████████| 33/33 [00:11<00:00,  3.00it/s]


Epoch 33 training loss:  0.0043383447176011305


100%|██████████| 11/11 [00:00<00:00, 13.01it/s]


Validation loss:  0.1152022325592797
Validation acc:  0.6829268292682927
Best valid accuracy: 0.7804878048780488


100%|██████████| 33/33 [00:10<00:00,  3.00it/s]


Epoch 34 training loss:  0.0036534434010250015


100%|██████████| 11/11 [00:00<00:00, 13.05it/s]


Validation loss:  0.11528771861297328
Validation acc:  0.7195121951219512
Best valid accuracy: 0.7804878048780488


0,1
train epoch accuracy,▁▅▇▇▇▇▇▇▇██████████████████████████
train epoch loss,█▆▄▃▂▂▂▂▂▁▂▂▂▂▂▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁
valid accuracy,▁▆▅▇▆▆▆▇▇█▇▆▇▇▇▇▆█▇▇▆█▆▇█▇███▅▇▆▇▆▇
valid epoch loss,▃▃▄▂▂▄▅▅▃▁▃▃▄▅▃▄▆▄▆█▅▆█▅▃▄▄▅▆▇▇▆▅▅▅

0,1
train epoch accuracy,0.98864
train epoch loss,0.00365
valid accuracy,0.71951
valid epoch loss,0.11529


In [26]:
def test_model(model, test_dataloader, criterion):
    model.eval()
    with torch.no_grad():
        correct = 0
        for _, (inputs, labels) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
            inputs = inputs.float().to(device)
            labels = labels.to(device)
            pred = model(inputs)
            pred_labels = pred.argmax(1)
            correct += (pred_labels == labels).sum().item()
        acc = correct / len(test_dataloader.dataset)
        print("Test acc: ", acc)

In [28]:
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
model = models.efficientnet_b5()
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 2)
load_checkpoint(model, save_name="effnet_noaug.tar")
model = model.to(device)
test_model(model, test_dataloader, nn.CrossEntropyLoss())

Loading checkpoint...
Done!


100%|██████████| 8/8 [00:00<00:00, 12.60it/s]

Test acc:  0.6666666666666666



