In [18]:
import json, os
import torch, torchvision
from torch import nn
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
from sentence_transformers import SentenceTransformer
import random
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights
import trainer

In [2]:
max_length = 225
amount_images_in_new_dataset = 150
batch_size = 32
dataset_save_dir = "encoded_data/data"
new_dataset_dir = "landing_site_classification"
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:

json_path= os.path.join("mars_dataset", "meta_clean.json")
json_data = json.load(open(json_path))
python_data = {}

for data_point in json_data:
    python_data[data_point] = json_data[data_point]["title"] + " " + json_data[data_point]["caption"]

In [4]:
python_data["ESP_014159_1670"]

'Fan in Southern Highlands Crater Many channels formed by water (when the climate of Mars was very different to that of today) cut the ancient highlands of Mars. Water running through these channels picks up rocky debris and carries it or rolls it along the channel bed. Occasionally these channels will empty into a crater or other low point in the terrain and the water will drop the material it is transporting. This material can build up in large fan-shaped mounds at the end of the channel. In this observation, this is likely what has happened. The fan-shaped mound (which appears bluish in this false-color image) sits at the end of a short channel. Analysis of spectroscopic data shows that the composition of this material indicates a history of interaction with liquid water. The full resolution version of this HiRISE image shows layering that indicates this material was dumped here in at least three separate episodes. Although they may once have been common, features like this are now 

In [5]:
model = SentenceTransformer('paraphrase-MiniLM-L12-v2')

class BertEncodedDataset(torch.utils.data.Dataset):
    def __init__(self, data_dict, model=model):
        self.classes = [name for name in data_dict]
        if not os.path.exists(dataset_save_dir):
            text = [model.encode(text, convert_to_tensor=True) for text in list(data_dict.values())]
            data = torch.stack(text)
            os.makedirs(dataset_save_dir.split("/")[0], exist_ok=True)
            # y = torch.arange(len(data))
            np.savez(dataset_save_dir, data.cpu())
        npzfile = np.load(dataset_save_dir + ".npz")
        self.encoded = torch.from_numpy(npzfile['arr_0'])

    def __len__(self) -> int:
        "Returns the total number of samples."
        return len(self.encoded)
    
    def __getitem__(self, index: int):
        return self.encoded[index], self.classes[index]

In [6]:
dataset = BertEncodedDataset(python_data)

In [7]:
cos = nn.CosineSimilarity(dim=0)

query = model.encode("flat, water, equator", convert_to_tensor=True)

similar_dict = {}

for desc, img in dataset:
    similar_val = cos(query.to(device), desc.to(device))
    similar_dict[similar_val.item()] = img

sorted_dict = dict(sorted(similar_dict.items(), reverse=True))

In [8]:
landing_site_images = list(sorted_dict.values())[:amount_images_in_new_dataset]
not_landing_site_images = random.choices(list(sorted_dict.values())[amount_images_in_new_dataset:], k=amount_images_in_new_dataset)

In [9]:
if not os.path.exists(new_dataset_dir):
    os.makedirs(new_dataset_dir)
    landing_site_path = new_dataset_dir + "/landing_site"
    not_landing_site_path = new_dataset_dir + "/not_landing_site"
    os.makedirs(landing_site_path)
    os.makedirs(not_landing_site_path)
    for file_name in landing_site_images:
        img = Image.open(f"mars_dataset/images/{file_name}.jpg")
        img.save(landing_site_path + f"/{file_name}.jpg")  
    for file_name in not_landing_site_images:
        img = Image.open(f"mars_dataset/images/{file_name}.jpg")
        img.save(not_landing_site_path + f"/{file_name}.jpg")  

In [10]:
model = resnet50()
model.conv1 = nn.Conv2d(10, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model.fc = nn.Linear(2048, 17)
model.load_state_dict(torch.load("resnet50-sentinel2.pt"))
for param in model.parameters():
    param.requires_grad = False
model.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model.fc = nn.Linear(2048, 1)
transform = ResNet50_Weights.DEFAULT.transforms()

### Try using a model that has 10 input channels, try loading a model from pytorch

In [11]:
dataset = torchvision.datasets.ImageFolder(new_dataset_dir, transform=transform)
n_test = int(0.15 * len(dataset))  # take ~10% for test
test_set = torch.utils.data.Subset(dataset, range(n_test))  # take first 10%
train_set = torch.utils.data.Subset(dataset, range(n_test, len(dataset))) 

train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)

In [12]:
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [19]:
results = trainer.train(model=model,
                       train_dataloader=train_dataloader,
                       test_dataloader=test_dataloader,
                       optimizer=optimizer,
                       loss_fn=loss_fn,
                       epochs=5,
                       device=device)

  0%|          | 0/5 [00:00<?, ?it/s]


heloo


RuntimeError: result type Float can't be cast to the desired output type Long