In [1]:
from PIL import Image
import os

import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split

from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

In [2]:
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 85.0MB/s]


In [3]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [5]:
print(model.parameters())

<generator object Module.parameters at 0x7e756c5df300>


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class customResNet50MultiTask(nn.Module):
    def __init__(self, num_classes_task1=10, num_classes_task2=5, input_size=(3, 224, 224)):
        super(customResNet50MultiTask, self).__init__()

        # Load pretrained ResNet50
        self.resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        # Freeze all the convolutional layers (ResNet50's layers before the fully connected layers)
        for param in self.resnet50.parameters():
            param.requires_grad = False

        # Modify the fully connected layer
        in_features = self.resnet50.fc.in_features
        self.resnet50.fc = nn.Identity()  # Remove the original fully connected layer

        # Task-specific layers
        # Task 1
        self.task1_fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes_task1)  # Output for task 1
        )

        # Task 2
        self.task2_fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes_task2)  # Output for task 2
        )

    def forward(self, x):
        # Shared layers from ResNet50
        x = self.resnet50.conv1(x)
        x = self.resnet50.bn1(x)
        x = self.resnet50.relu(x)
        x = self.resnet50.maxpool(x)

        x = self.resnet50.layer1(x)
        x = self.resnet50.layer2(x)
        x = self.resnet50.layer3(x)
        x = self.resnet50.layer4(x)

        x = self.resnet50.avgpool(x)
        x = torch.flatten(x, 1)  # Flatten before passing to the task-specific heads

        # Task-specific outputs
        task1_out = self.task1_fc(x)
        task2_out = self.task2_fc(x)

        # You can apply softmax or log_softmax depending on the loss function you're using
        task1_out = F.log_softmax(task1_out, dim=1)
        task2_out = F.log_softmax(task2_out, dim=1)

        return task1_out, task2_out

# Example of creating the model for 2 tasks
# Task 1 has 10 output classes, and Task 2 has 5 output classes
model = customResNet50MultiTask(num_classes_task1=10, num_classes_task2=5)
print(model)


customResNet50MultiTask(
  (resnet50): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Seque

In [None]:
predl = (torch.sigmoid(outputsl) > 0.5).float()
  train_correctl += predl.eq(labell.view_as(predl)).sum().item()

In [None]:
 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])