In [18]:
# Import the packages
import torch
import os
import argparse
import numpy as np
import os.path as osp
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from tqdm import tqdm
#from torcheval.metrics import R2Score # To be implemented
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
from torch.nn import BatchNorm1d
import torch.optim as optim

from aux_func import*
from datasets.datset_process import *

# Import the ResNet-50 architecture (resnet18/resnet101 are drop-in alternatives); it is built without pretrained weights
from torchvision.models import resnet50

In [None]:
# Define arg parser and make the run reproducible.
seed = 200
paser = argparse.ArgumentParser()
args = paser.parse_args([])  # empty argv: the notebook fills `args` manually below

# Seed every RNG that the notebook uses from the single `seed` constant.
np.random.seed(seed)
torch.manual_seed(seed)

# Device selection without an interactive prompt (input() blocks Restart & Run All):
# an explicit DEVICE environment variable wins, otherwise CUDA is used when available.
device = torch.device(os.environ.get("DEVICE", "cuda" if torch.cuda.is_available() else "cpu"))
device

In [8]:
# Take user inputs: all experiment settings live on the shared `args` namespace.
vars(args).update(
    # Dataset
    dataset='miniimagenet',
    data_path='',
    num_classes=64,   # 64 by default for miniimagenet
    image_size=84,
    # Few-shot episode definition
    num_ways=5,       # classes per episode
    k_shot=5,         # labelled images per class
    query=15,         # query images per class
    unlabel=50,       # unlabeled images per class
    step=5,           # unlabeled samples selected per class in one iteration
    # With 5 classes per support set, equally probable classes would each have
    # p = 0.2, which is used as the selection threshold.
    threshold=0.2,
    # Training / testing
    episodes=600,
)

# Episode sizes for semi-supervised few-shot learning.
num_support = args.num_ways * args.k_shot
num_query = args.num_ways * args.query
num_unlabeled = args.num_ways * args.unlabel

In [None]:
# Number of selection rounds over the unlabeled data: `args.unlabel` samples per
# class consumed `args.step` at a time (the attribute is `step`, not `steps`).
num_select = args.unlabel // args.step

In [None]:
# Build the backbone without pretrained weights, with `args.num_classes` outputs.
# `num_classes` must be passed by keyword: it is not the first positional
# parameter of torchvision's resnet50 (older: `pretrained`, newer: keyword-only).
model = resnet50(num_classes=args.num_classes)

# Define the last (classifier) layer that is trained on the extracted features.
# NOTE(review): 2048 * 3 * 3 matches a ResNet-50 layer4 map for 84x84 inputs
# (84 -> 42 -> 21 -> 21 -> 11 -> 6 -> 3) flattened without global average
# pooling. torchvision's resnet50 pools to 2048 features and has no
# `return_feat` flag (used by get_features), so the backbone is presumably meant
# to be a custom variant — TODO confirm the feature size.
last_layer = nn.Linear(in_features=2048 * 3 * 3, out_features=args.num_ways)

In [7]:
#  Get the features from the resnet model

def get_features(model, input, batch_size=64, device=None):
    '''
    Embed `input` with `model` and return the features as a NumPy array.

    The input is processed in chunks of at most `batch_size` samples to bound
    memory. Each chunk is passed through `model(chunk, return_feat=True)`, the
    result is moved to the CPU, and the chunks are concatenated in order.

    Args:
        model: backbone whose forward accepts a `return_feat=True` keyword.
        input: tensor whose first dimension is the batch dimension.
        batch_size: maximum number of samples per forward pass (default 64).
        device: device to run the forward pass on. Defaults to the device of
            the model's first parameter, or the CPU if it has none.

    Returns:
        numpy.ndarray with one row (axis 0) per input sample.
    '''
    if device is None:
        params = getattr(model, 'parameters', None)
        first_param = next(params(), None) if callable(params) else None
        device = first_param.device if first_param is not None else torch.device('cpu')

    n = input.shape[0]
    chunks = []
    # Feature extraction needs no autograd graph; the result was detached anyway.
    with torch.no_grad():
        # max(n, 1) keeps a single (possibly empty) forward pass for empty input;
        # for n <= batch_size this is one call on the whole input.
        for start in range(0, max(n, 1), batch_size):
            feats = model(input[start:start + batch_size].to(device), return_feat=True)
            chunks.append(feats.detach().cpu())
    embed = torch.cat(chunks)
    assert embed.shape[0] == n  # one feature row per input sample
    return embed.numpy()

In [None]:
def train_loop(model, dataset, loss_fn, optimizer, inputs, targets):
    """Run a single optimisation step of `model` on one batch.

    The loss is `loss_fn(logits, targets) + entropy_loss(logits)`; the entropy
    term acts as a regulariser. `entropy_loss` is not defined in this notebook
    (presumably it comes from the `aux_func` star import — TODO confirm).

    Args:
        model: torch module mapping `inputs` to logits.
        dataset: unused; kept for interface compatibility.
        loss_fn: supervised criterion, e.g. `nn.CrossEntropyLoss`.
        optimizer: optimiser over `model`'s parameters.
        inputs: array-like or tensor of features. It is converted to the dtype
            and device of the model parameters, because sklearn's `normalize`
            yields float64, which a float32 `nn.Linear` rejects.
        targets: array-like or tensor of targets; integer class ids are cast to
            int64 as required by `nn.CrossEntropyLoss`.

    Returns:
        (loss, logits): the loss as a Python float and the logits tensor from
        the forward pass (computed before the parameter update).
    """
    first_param = next(model.parameters(), None)
    dtype = first_param.dtype if first_param is not None else torch.float32
    device = first_param.device if first_param is not None else None

    # as_tensor avoids the copy warning torch.tensor() emits for tensor inputs.
    inputs = torch.as_tensor(inputs, dtype=dtype, device=device)
    targets = torch.as_tensor(targets, device=device)
    if not targets.is_floating_point():
        targets = targets.long()

    model.train()
    optimizer.zero_grad()
    logits = model(inputs)
    # Combination of two losses: the entropy term acts as a regulariser.
    loss = loss_fn(logits, targets) + entropy_loss(logits)
    loss.backward()
    optimizer.step()
    return loss.item(), logits

In [None]:
def get_preds(out):
    """Return the index of the smallest entry of `out` along dim 0, twice.

    In this notebook `out` is a softmax probability vector, so the index is the
    least-likely class. The value is duplicated in a 2-tuple; callers unpack it
    as `t, _ = get_preds(out)`.
    """
    least_likely = out.argmin(dim=0).item()
    return least_likely, least_likely

In [None]:

def get_preds_position_(unlabel_out, position, _postion, thres=0.001):
    """Assign at most one new complementary (negative) label per unlabeled sample.

    For sample `idx`, `position[idx]` lists the classes still considered
    possible and `_postion[idx]` the classes already excluded, in removal
    order. The logits of the still-possible classes are softmax-normalised. If
    the smallest probability is at most `thres`, that class is moved from
    `position[idx]` to `_postion[idx]` and becomes the sample's new negative
    label. Both lists are mutated in place.

    Args:
        unlabel_out: tensor of logits with one row per sample, indexed as
            `unlabel_out[idx][pos]`.
        position: per-sample lists of remaining candidate classes (mutated).
        _postion: per-sample lists of excluded classes (mutated).
        thres: maximum probability for a class to be excluded.

    Returns:
        (r, un_idx):
        - `r` is a NumPy array with exactly one entry per sample. For selected
          samples it holds the new negative label. For unselected samples it
          holds a placeholder class id, so `r[i]` always refers to sample `i`.
        - `un_idx` lists the samples that got no new negative label: a single
          candidate is left, or no probability is at or below `thres`.
    """
    r = []
    un_idx = []
    for idx, (pos, _pos) in enumerate(zip(position, _postion)):
        if len(pos) == 1:
            # Only one candidate left: nothing to exclude.
            un_idx.append(idx)
            r.append(pos[0])  # placeholder keeps r aligned with sample indices
            continue
        # Logits of the remaining classes form a 1-D vector, so softmax over dim 0.
        out = F.softmax(unlabel_out[idx][pos], dim=0)
        min_prob, min_pos = torch.min(out, dim=0)
        conf, t = min_prob.item(), min_pos.item()
        if conf > thres:
            # Not confident enough that any remaining class is wrong.
            un_idx.append(idx)
            # Placeholder class id: last excluded class, else the least-likely candidate.
            r.append(_pos[-1] if _pos else pos[t])
            continue
        a = pos[t]
        _pos.append(a)
        pos.remove(a)
        r.append(a)
    return np.asarray(r), un_idx

In [None]:
# TODO: define the dataset and `train_loader` (used by the training cell below) — this cell is currently empty


In [None]:
# Start the training process.
#
# NOTE(review): the query set is extracted but not used yet; it should be
# concatenated with the unlabeled data / used for evaluation.
# NOTE(review): `train_loader`, `entropy_loss`, `train_loop_NL` and `NL_loss` are not
# defined in this notebook; presumably they come from the star imports or a cell
# still to be written — TODO confirm.

# The backbone is only used as a feature extractor here: put it on the target
# device and in eval mode, so BatchNorm does not use/update small-batch statistics.
model = model.to(device).eval()

# Class ids used to encode the candidate positions of each unlabeled sample.
CLASS_IDS = list(range(args.num_ways))


def _to_features(batch):
    """Embed `batch` with the backbone, L2-normalise rows, and cast to float32 for `last_layer`."""
    return normalize(get_features(model, batch)).astype(np.float32)


for data in tqdm(train_loader):
    # Create the support / query / unlabeled splits of one episode.
    # Assumes the loader yields one image tensor per episode with classes
    # interleaved 0..num_ways-1 (matching `targets` below) — TODO confirm.
    data = data.to(device)
    targets = torch.arange(args.num_ways).repeat(args.k_shot + args.query + args.unlabel).long()

    support_data = data[:num_support]
    query_data = data[num_support:num_support + num_query]
    unlabel_data = data[num_support + num_query:]

    support_inputs = _to_features(support_data)
    support_targets = targets[:num_support].cpu().numpy()

    query_inputs = _to_features(query_data)
    query_targets = targets[num_support:num_support + num_query].cpu().numpy()

    unlabel_inputs = _to_features(unlabel_data)
    unlabel_targets = targets[num_support + num_query:].cpu().numpy()

    # Per-sample bookkeeping for complementary (negative) labels:
    #   POSITION[i]  - classes still considered possible for unlabeled sample i
    #   _POSITION[i] - classes already excluded for sample i, in removal order
    ori_index = list(range(num_unlabeled))
    _POSITION = [[] for _ in range(num_unlabeled)]
    POSITION = [list(CLASS_IDS) for _ in range(num_unlabeled)]

    # The classifier is the linear `last_layer` (output dim = num_ways).
    # weight_decay adds L2 regularisation.
    criterion = nn.CrossEntropyLoss(reduction='mean')
    optimizer = torch.optim.SGD(last_layer.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)

    # Initial training on the labelled support set.
    print('\n********************************************  Initial training the model')
    for epoch in range(100):
        loss, _ = train_loop(last_layer, None, criterion, optimizer, support_inputs, support_targets)
        print(f"Train_Epoch: {epoch}  Train_Loss: {loss}")

    # Iterative training with complementary labels.
    print('\n********************************************  Training with complimentatry labels')
    while True:
        # Score every unlabeled sample; selection needs no gradients.
        with torch.no_grad():
            unlabel_out = last_layer(torch.as_tensor(unlabel_inputs))
        nl_pred, unselect_idx = get_preds_position_(unlabel_out, POSITION, _POSITION, args.threshold)
        unselected = set(unselect_idx)
        select_idx = [x for x in ori_index if x not in unselected]

        # Stop once no sample received a new confident negative label
        # (checked before training so we never train on an empty set).
        if len(select_idx) == 0:
            break

        _unlabel_embeddings = unlabel_inputs[select_idx]
        _unlabel_t = unlabel_targets[select_idx]
        nl_pred = nl_pred[select_idx]
        # A negative label is correct when it differs from the sample's class.
        print(f"Selected: {len(select_idx)}  NL_acc: {np.mean(nl_pred != _unlabel_t):.4f}")

        optimizer_NL = torch.optim.SGD(last_layer.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
        for epoch in range(10):
            loss = train_loop_NL(last_layer, None, NL_loss, optimizer_NL, _unlabel_embeddings, nl_pred)
            print(f"Epoch: {epoch}  Loss: {loss}")

        # Samples with a single remaining candidate get it as a positive pseudo-label.
        index_pl = [idx for idx, item in enumerate(POSITION) if len(item) == 1]
        if not index_pl:
            continue
        pseudo_label = np.asarray([POSITION[idx][0] for idx in index_pl])
        t1_ = unlabel_inputs[index_pl]
        t2_ = torch.as_tensor(pseudo_label, dtype=torch.int64)
        for epoch in range(100):
            loss, _ = train_loop(last_layer, None, criterion, optimizer, t1_, t2_)
            print(f"Epoch: {epoch}  Loss: {loss}")
        
        

