# In [None]:
import argparse
from tqdm import tqdm
import cv2
import torch
import random
import numpy as np
from sklearn.manifold import TSNE
import timm
import matplotlib.pyplot as plt

from my_dataset2 import MyDataset, collate_skip_empty, colors_per_class


def fix_random_seeds():
    """Seed the Python, PyTorch and NumPy random generators with the value 50.

    Calling this before the pipeline makes runs repeatable with respect to
    these three generators.
    """
    seed = 50
    # Same order as before: Python's `random`, then torch, then NumPy.
    for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
        seed_fn(seed)
 
def get_features(dataset, batch, num_images,
                 checkpoint_path=r'C:\Users\rtm\Desktop\tsne_new\best.pt'):
    """Run the fine-tuned ResNet-50 over the dataset and collect its features.

    Args:
        dataset: Dataset location, forwarded unchanged to ``MyDataset``.
        batch: Batch size used by the ``DataLoader``.
        num_images: Forwarded unchanged to ``MyDataset``.
        checkpoint_path: File holding the state dict loaded into the model.
            Defaults to the path that used to be hard-coded here.

    Returns:
        A ``(features, labels, image_paths)`` tuple. ``features`` is a single
        NumPy array with the per-batch outputs concatenated along axis 0, or
        ``None`` when the loader yields no batches. ``labels`` and
        ``image_paths`` are lists accumulated in the same batch order.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = timm.create_model('resnet50', num_classes=4)
    # load_state_dict() updates the model in place and returns a report of
    # missing/unexpected keys, so its result must not be assigned to `model`.
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)

    model.eval()
    model.to(device)

    image_dataset = MyDataset(dataset, num_images)
    dataloader = torch.utils.data.DataLoader(
        image_dataset,
        batch_size=batch,
        collate_fn=collate_skip_empty,
        shuffle=True,
    )

    # Per-batch feature arrays; joined once after the loop instead of
    # re-copying the growing array on every iteration.
    feature_chunks = []

    # Labels and image paths, collected in the same order as the features.
    labels = []
    image_paths = []

    for samples in tqdm(dataloader, desc='Running the model inference'):
        images = samples['image'].to(device)
        labels += samples['label']
        image_paths += samples['image_path']

        with torch.no_grad():
            output = model.forward_features(images)
            # For a CNN backbone forward_features yields an unpooled
            # (batch, channels, height, width) map. Average over the spatial
            # axes so each image becomes one vector, which is the 2-D
            # (n_samples, n_features) layout the downstream t-SNE step needs.
            if output.dim() == 4:
                output = output.mean(dim=(2, 3))

        feature_chunks.append(output.cpu().numpy())

    features = np.concatenate(feature_chunks) if feature_chunks else None

    return features, labels, image_paths


# scale and move the coordinates so they fit [0; 1] range
def scale_to_01_range(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    Args:
        x: Array-like of numbers.

    Returns:
        A NumPy array with the same shape as ``x``. When every value in ``x``
        is identical the range is zero; an all-zero float array is returned
        in that case instead of dividing by zero and producing NaNs.

    Raises:
        ValueError: If ``x`` is empty (raised by ``np.min``).
    """
    x = np.asarray(x)

    lowest = np.min(x)
    # compute the distribution range
    value_range = np.max(x) - lowest

    # move the distribution so that it starts from zero
    starts_from_zero = x - lowest

    if value_range == 0:
        # Degenerate distribution: nothing to stretch, avoid 0 / 0.
        return np.zeros_like(starts_from_zero, dtype=float)

    # make the distribution fit [0; 1] by dividing by its range
    return starts_from_zero / value_range


def scale_image(image, max_image_size):
    """Shrink ``image`` so neither side exceeds ``max_image_size`` pixels.

    The aspect ratio is kept. Images that already fit are not enlarged
    (the scale factor is never below 1).

    Args:
        image: Array whose ``shape`` is ``(height, width, channels)``.
        max_image_size: Largest allowed width/height in pixels.

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    image_height, image_width, _ = image.shape

    scale = max(1, image_width / max_image_size, image_height / max_image_size)
    # Truncation can drive the short side of a very elongated image to 0,
    # which cv2.resize rejects; keep at least one pixel per side.
    new_width = max(1, int(image_width / scale))
    new_height = max(1, int(image_height / scale))

    return cv2.resize(image, (new_width, new_height))


def draw_rectangle_by_class(image, label):
    """Outline ``image`` with a 5-pixel border in the colour of its class.

    Args:
        image: Array whose ``shape`` is ``(height, width, channels)``.
        label: Key into ``colors_per_class`` selecting the border colour.

    Returns:
        The image returned by ``cv2.rectangle``.
    """
    rows, cols, _ = image.shape

    # The border runs from the top-left pixel to the bottom-right pixel.
    top_left = (0, 0)
    bottom_right = (cols - 1, rows - 1)

    return cv2.rectangle(
        image,
        top_left,
        bottom_right,
        color=colors_per_class[label],
        thickness=5,
    )


def compute_plot_coordinates(image, x, y, image_centers_area_size, offset):
    """Locate the canvas rectangle in which ``image`` should be pasted.

    Args:
        image: Array whose ``shape`` is ``(height, width, channels)``.
        x: Horizontal position of the image centre, scaled by
            ``image_centers_area_size``.
        y: Vertical position of the image centre, mirrored (``1 - y``) and
            scaled by ``image_centers_area_size``.
        image_centers_area_size: Side length of the area that image centres
            may occupy.
        offset: Margin added to both centre coordinates.

    Returns:
        ``(tl_x, tl_y, br_x, br_y)``: the top-left and bottom-right corners.
    """
    height, width, _ = image.shape

    # matplotlib's y axis points upward while image rows grow downward,
    # so the y coordinate is mirrored to keep the same orientation.
    centre_x = offset + int(image_centers_area_size * x)
    centre_y = offset + int(image_centers_area_size * (1 - y))

    # Step back half the image size from the centre to reach the top-left
    # corner, then add the full size to reach the bottom-right corner.
    left = centre_x - int(width / 2)
    top = centre_y - int(height / 2)

    return left, top, left + width, top + height


# Render the t-SNE embedding as a canvas of image thumbnails
def visualize_tsne_images(tx, ty, images, labels, plot_size=1000, max_image_size=100):
    """Paste a thumbnail of every image at its t-SNE position and show it.

    Args:
        tx: Horizontal coordinates, one per image (the caller in this file
            passes values scaled to [0; 1]).
        ty: Vertical coordinates, one per image (same scaling).
        images: Paths of the image files, read with ``cv2.imread``.
        labels: Class label of each image; selects the border colour.
        plot_size: Side length in pixels of the square canvas.
        max_image_size: Largest allowed thumbnail side in pixels.

    Images that ``cv2.imread`` cannot read are skipped with a warning instead
    of aborting the whole plot.
    """
    import warnings

    # Leave a margin of half a thumbnail so images centred on the border of
    # the plotting area still fit on the canvas.
    offset = max_image_size // 2
    image_centers_area_size = plot_size - 2 * offset

    # White canvas.
    tsne_plot = 255 * np.ones((plot_size, plot_size, 3), np.uint8)

    for image_path, label, x, y in tqdm(
            zip(images, labels, tx, ty),
            desc='Building the T-SNE plot',
            total=len(images)
    ):
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals a missing/undecodable file by returning None
            # rather than raising.
            warnings.warn(f'Could not read image, skipping: {image_path}')
            continue

        image = scale_image(image, max_image_size)
        image = draw_rectangle_by_class(image, label)

        tl_x, tl_y, br_x, br_y = compute_plot_coordinates(image, x, y, image_centers_area_size, offset)

        tsne_plot[tl_y:br_y, tl_x:br_x, :] = image

    # Reverse the channel axis for display (OpenCV loads images as BGR,
    # matplotlib expects RGB).
    plt.imshow(tsne_plot[:, :, ::-1])
    plt.show()

# Draw the t-SNE embedding as a scatter plot of points
def visualize_tsne_points(tx, ty, labels):
    """Scatter-plot the embedding with one colour and legend entry per class.

    Args:
        tx: Horizontal coordinates, one per sample.
        ty: Vertical coordinates, one per sample.
        labels: Class label of each sample; compared against the keys of
            ``colors_per_class``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)

    for label in colors_per_class:
        # Positions of the samples that belong to the current class.
        indices = [i for i, sample_label in enumerate(labels) if sample_label == label]

        current_tx = np.take(tx, indices)
        current_ty = np.take(ty, indices)

        # The stored colour is reversed (presumably BGR -> RGB, since the
        # same table is also handed to OpenCV) and scaled from 0..255 to the
        # 0..1 range. The builtin `float` replaces the `np.float` alias,
        # which was removed in NumPy 1.24.
        color = np.array([colors_per_class[label][::-1]], dtype=float) / 255

        ax.scatter(current_tx, current_ty, c=color, label=label)

    ax.legend(loc='best')

    plt.show()
    #plt.savefig('local_tsne_5000')


def visualize_tsne(tsne, images, labels, plot_size=1000, max_image_size=100):
    """Scale a 2-D embedding to the [0; 1] range and show it as a point plot.

    Args:
        tsne: Array indexed as ``tsne[:, 0]`` / ``tsne[:, 1]`` for the two
            embedding axes.
        images: Image paths; only used by the disabled thumbnail plot.
        labels: Class label of each sample.
        plot_size: Only used by the disabled thumbnail plot.
        max_image_size: Only used by the disabled thumbnail plot.
    """
    # Each axis is normalised independently.
    tx, ty = (scale_to_01_range(tsne[:, axis]) for axis in (0, 1))

    visualize_tsne_points(tx, ty, labels)

    #visualize_tsne_images(tx, ty, images, labels, plot_size=plot_size, max_image_size=max_image_size)


def main():
    """Parse the command line, extract features, run t-SNE and plot it.

    Command-line options:
        --path: Dataset location passed to ``get_features`` (default ``data``).
        --batch: Inference batch size (default 64).
        --num_images: Passed to ``get_features`` (default 500).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--path', type=str, default='data')
    parser.add_argument('--batch', type=int, default=64)
    parser.add_argument('--num_images', type=int, default=500)
    args = parser.parse_args()

    fix_random_seeds()

    features, labels, image_paths = get_features(
        dataset=args.path,
        batch=args.batch,
        num_images=args.num_images
    )
    if features is None:
        # get_features returns None when no batch was produced; stop with a
        # readable message instead of failing on `.shape` below.
        parser.error(f'no features were extracted from {args.path!r}')
    print(features.shape)

    # The iteration count is left at the TSNE default, which is the 1000
    # that was previously passed explicitly. The keyword itself was renamed
    # from `n_iter` to `max_iter` in scikit-learn 1.5, so omitting it keeps
    # this call working on both older and newer releases.
    tsne = TSNE(n_components=2, perplexity=20).fit_transform(features)

    visualize_tsne(tsne, image_paths, labels)

# Run the pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()