In [1]:
import torch
import torchvision
import torch.nn as nn
from utils import compute_iou
from data.cbis_ddsm import CBIS_DDSM, get_loaders
import torch.optim as optim
from torch.optim import lr_scheduler
import time
from collections import defaultdict
from tqdm import tqdm




In [2]:
# Check that a GPU is available and select it as the training device.
# This notebook is GPU-only: fail fast instead of silently training on the CPU.
if not torch.cuda.is_available():
    raise RuntimeError("GPU not available. CPU training will be too slow.")

print("device name", torch.cuda.get_device_name(0))

# The guard above has already established that CUDA is available, so no CPU
# fallback is needed here (it could never be reached).
device = torch.device('cuda')
print('device', device)

# NOTE(review): this only binds a Python variable; it does not touch the process
# environment, so it has no effect on CUDA kernel launches. If synchronous
# launches were intended for debugging, the setting presumably has to be exported
# as an environment variable before CUDA is initialised -- confirm the intent.
CUDA_LAUNCH_BLOCKING = 1

device name NVIDIA GeForce GTX 1080 Ti
device cuda


In [3]:
def print_metrics(metrics, epoch_samples):
    """Print the per-sample average of every accumulated metric on one line.

    Args:
        metrics: mapping from metric name to a running sum.
        epoch_samples: divisor applied to every running sum.

    Output has the form ``Metrics: name: value, name: value``; each value is
    rendered with the ``{:4f}`` format spec (minimum width 4, default precision).
    """
    entry_template = "{}: {:4f}"
    formatted = [
        entry_template.format(name, running_sum / epoch_samples)
        for name, running_sum in metrics.items()
    ]
    print("Metrics: {}".format(", ".join(formatted)))

In [4]:
def calc_metrics(outputs, target, metrics, bce_weight=0.5):
    """Accumulate sample-weighted BCE statistics for one batch into ``metrics``.

    The previous version read ``bce`` and ``loss`` without ever computing them
    and therefore raised ``NameError`` on every call.

    Args:
        outputs: raw model logits (no sigmoid applied).
        target: ground-truth tensor; it is cast to the dtype of ``outputs``.
            ``binary_cross_entropy_with_logits`` requires it to have the same
            shape as ``outputs``.
        metrics: mutable mapping of running sums (e.g. ``defaultdict(float)``);
            the keys ``'bce'`` and ``'loss'`` are incremented in place.
        bce_weight: weight of the BCE term in the combined loss.

    Returns:
        None. Results are reported only through ``metrics``.
    """
    # The functional takes raw logits and applies the sigmoid internally.
    target = target.to(dtype=outputs.dtype)
    bce = nn.functional.binary_cross_entropy_with_logits(outputs, target)

    # The dice term of the original formula
    #     loss = bce * bce_weight + dice * (1 - bce_weight)
    # is disabled in this notebook (no dice_loss is in scope), so only the
    # weighted BCE term remains. With bce_weight=1 the loss equals the BCE.
    loss = bce * bce_weight

    # Weight by batch size so that dividing the running sums by the number of
    # samples seen yields a per-sample average even with uneven batches.
    batch_size = target.size(0)
    metrics['bce'] += bce.item() * batch_size
    metrics['loss'] += loss.item() * batch_size

In [5]:
def test_model(model, testloader):
    
    model.eval()
    
    correct = 0
    total = 0
    
    with torch.no_grad():
        
        for i, batch in enumerate(tqdm(testloader)):

            inputs, target = batch['image'], batch['target']
            # mask = mask.mean(1,keepdim=True)  #transform from [batch_size, 3, size, size] to [batch,size, 1, size, size]
            inputs = inputs.to(device)
            target = target.to(device)
            
            outputs = model(inputs)
            
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    print(f'Test Acc: {100 * correct // total} %')

In [11]:
### Parameters

num_epochs = 10

# Binary problem: the classifier head emits two scores.
num_class = 2

# ResNet-18 loaded with pretrained weights; its final fully connected layer is
# replaced by a fresh linear layer producing `num_class` outputs.
model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=num_class)
model = model.to(device)

criterion = nn.CrossEntropyLoss()

# Every parameter is trained (no layers are frozen).
optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 after every 8 scheduler steps.
scheduler = lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)

# Dataloader parameters, interpreted by get_loaders.
# NOTE(review): the difference between "size1" and "size" is not visible from
# this cell -- confirm against get_loaders.
args = dict(
    data_path="/datasets/mamografia/CBIS-DDSM_organized/images/preprocessed",
    batch_size=64,
    size1=220,
    size=200,
)

train_loader, test_loader = get_loaders(args)

In [None]:


## Training Loop
# Lowest mean training loss observed over the epochs run so far.
best_loss = 1e10

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    since = time.time()
    model.train()  # test_model() leaves the model in eval mode, so re-enable training every epoch

    metrics = defaultdict(float)  # running sums, sample-weighted
    epoch_samples = 0
    total = 0
    correct = 0

    print(f"Epoch {epoch} - Training...\n")
    for i, batch in enumerate(tqdm(train_loader)):
        inputs = batch['image'].to(device)
        target = batch['target'].to(device)

        # Standard optimisation step: reset gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        # Statistics. The loss is weighted by batch size so the epoch average
        # stays correct when the last batch is smaller than the others.
        batch_size = inputs.size(0)
        batch_loss = loss.item()
        epoch_samples += batch_size
        metrics['loss'] += batch_loss * batch_size

        _, predicted = torch.max(outputs.detach(), 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

        # tqdm.write prints the line without breaking the progress bar,
        # which a plain print() inside the loop does.
        tqdm.write(f'[{epoch + 1}, {i + 1:5d}] loss: {batch_loss:.3f}')

    scheduler.step()

    if epoch_samples > 0:
        print_metrics(metrics, epoch_samples)
        best_loss = min(best_loss, metrics['loss'] / epoch_samples)
        print(f'Train Acc: {100 * correct // total} %')
    else:
        # An empty loader would otherwise cause a division by zero.
        print('Train Acc: n/a (no samples)')

    # Evaluate on the held-out split after every epoch.
    print("Testing...")
    test_model(model, test_loader)

    time_elapsed = time.time() - since
    print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

# torch.save(model.state_dict(), "trained_model.pth")



Epoch 0/9
----------
Epoch 0 - Training...



  5%|▌         | 1/20 [00:37<11:46, 37.17s/it]

[1,     1] loss: 0.721


 15%|█▌        | 3/20 [00:40<02:42,  9.55s/it]

[1,     2] loss: 0.692
[1,     3] loss: 0.771


 20%|██        | 4/20 [00:41<01:33,  5.82s/it]

[1,     4] loss: 0.738


 30%|███       | 6/20 [01:23<02:54, 12.45s/it]

[1,     5] loss: 0.704
[1,     6] loss: 0.749


 40%|████      | 8/20 [01:31<01:31,  7.66s/it]

[1,     7] loss: 0.677
[1,     8] loss: 0.696


 50%|█████     | 10/20 [02:22<02:26, 14.65s/it]

[1,     9] loss: 0.668
[1,    10] loss: 0.756


 60%|██████    | 12/20 [02:27<01:05,  8.15s/it]

[1,    11] loss: 0.688
[1,    12] loss: 0.797


 70%|███████   | 14/20 [03:39<01:54, 19.13s/it]

[1,    13] loss: 0.695
[1,    14] loss: 0.718


 80%|████████  | 16/20 [03:40<00:38,  9.55s/it]

[1,    15] loss: 0.703
[1,    16] loss: 0.657


 90%|█████████ | 18/20 [04:38<00:33, 16.90s/it]

[1,    17] loss: 0.768
[1,    18] loss: 0.714


100%|██████████| 20/20 [04:38<00:00, 13.93s/it]


[1,    19] loss: 0.724
[1,    20] loss: 0.560
Testing...


 33%|███▎      | 2/6 [01:11<01:58, 29.72s/it]