In [1]:
import numpy as np 
import pandas as pd 
import pickle

from FDSSC import models, utils
from datasets.datasets import LocationChecker, FDSSCDataset
from matplotlib import pyplot as plt
import torch
from torch.utils.data import DataLoader

In [2]:
# Build the commodity-code -> class-index mapping from the label CSV and cache
# it on disk; a later cell reloads this same pickle instead of recomputing it.
datadir = "data/stacked_images"
csv = "miningSites/masked_labels.csv"
id2idx = utils.get_id2idx(pd.read_csv(csv), "MAJOR_COMMODITY_CODE")

# `pickle` is already imported in the first cell. The context manager closes
# the file deterministically, so the data is flushed before it is read back.
with open("data/id2idxMasked.p", "wb") as mapping_file:
    pickle.dump(id2idx, mapping_file)

In [5]:
# Sanity check: show the commodity-code -> class-index mapping.
# NOTE(review): this cell's execution count (5) is higher than the cells below
# it, so the saved output may not match a top-to-bottom run -- confirm with
# Restart Kernel -> Run All.
print(id2idx)

{'NONE': 0, 'Ni': 1, 'Mn': 2, 'Ag': 3, 'Pb': 4, 'Cu': 5, 'Fe': 6, 'U': 7, 'Au': 8}


In [4]:
# Attach an integer class label to every row of the label table and save the
# result as a new CSV next to the original.
df = pd.read_csv(csv)

# Reuse the frame that was just loaded instead of parsing the same CSV again.
# A copy is handed to the helper so it receives data equivalent to a fresh
# read and cannot alter `df` before the new column is attached.
df["label"] = utils.create_idxlist(df.copy(), "MAJOR_COMMODITY_CODE", id2idx)

df.to_csv("miningSites/labeled_masked_labels.csv", index=False)

In [4]:
# Build the dataset and its batch loader, then reload the cached label mapping.
datadir = "data/stacked_images"
csv = "miningSites/masked_labels.csv"
checker = LocationChecker(datadir, size=(9, 9), rigorous=False)
landsat8Data = FDSSCDataset(datadir, csv, lochecker=checker)
dataloader = DataLoader(landsat8Data, batch_size=16, shuffle=True, num_workers=1)

# The mapping is read back from the pickle written earlier in this notebook
# rather than rebuilt with utils.get_id2idx(pd.read_csv(csv), "MAJOR_COMMODITY_CODE").
# NOTE: unpickling can execute arbitrary code -- only load files this project wrote.
with open("data/id2idxMasked.p", "rb") as mapping_file:
    id2idx = pickle.load(mapping_file)

# Inverse lookup: class index -> commodity code.
idx2id = {index: code for code, index in id2idx.items()}

# Per-code frequency table for the label CSV.
print(utils.freqdict(pd.read_csv(csv), "MAJOR_COMMODITY_CODE"))

{'MHEM': 4, 'DOLR': 1, 'GRNTK': 1, 'GRPH': 22, 'JADE': 113, 'TOZ': 1, 'Co': 6, 'SiO2': 4, 'TALC': 4, 'Sn': 5, 'OCR': 7, 'Ta': 1, 'OPAL': 143, 'MIC': 2, 'AGA': 2, 'GARN': 1, 'AMY': 2, 'FELD': 3, 'Pb': 82, 'ASBT': 10, 'BRL': 2, 'Cr': 8, 'Mn': 66, 'CAL': 5, 'FEST': 10, 'U': 52, 'Au': 307, 'Cu': 589, 'NPT': 1, 'LSND': 14, 'Th': 4, 'FLUR': 2, 'SALT': 31, 'PHOS': 26, 'GYPS': 68, 'ILM': 1, 'CYP': 2, 'CALI': 80, 'V': 1, 'MSLT': 162, 'GRVL': 38, 'Mo': 1, 'TRN': 1, 'QTZE': 55, 'FEORE': 14, 'SGRT': 5, 'KAO': 25, 'SLST': 31, 'SHA': 11, 'PALY': 4, 'MARB': 12, 'CLAY': 96, 'BAS': 1, 'EPS': 16, 'COAL': 34, 'QZ': 2, 'REE': 4, 'BAR': 73, 'DOL': 57, 'Ni': 18, 'GRNT': 58, 'GBRO': 1, 'GNSS': 7, 'Fe': 139, 'Ag': 28, 'SIS': 9, 'OSH': 1, 'HLY': 1, 'MAGS': 24, 'S': 1, 'SDST': 113, 'LMST': 68, 'PTM': 1, 'DIA': 39, 'HMIN': 38, 'COR': 1, 'W': 2, 'SLAT': 17, 'Ra': 1, 'Zn': 53, 'RHYO': 6, 'SAND': 131, 'ALU': 5, 'Al': 2, 'CELES': 23}


In [6]:
# training model
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# One output per entry in the label mapping. The 9x9 in the input shape matches
# the `size=(9, 9)` given to LocationChecker; the remaining dimensions are
# presumably channel and band count -- TODO confirm against FDSSC.models.
# The model is moved to the target device before the optimizer is built, so the
# optimizer is constructed from the parameters as they live on that device.
model = models.FerDSSC_model((1, 9, 9, 11), len(id2idx)).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Last expression of the cell: display the model architecture.
model

FerDSSC_model(
  (input_spec_conv): Conv3d(1, 24, kernel_size=(1, 1, 7), stride=(1, 1, 1), padding=(0, 0, 3))
  (spectral_conv1): Spectral_conv(
    (bn_prelu): Bn_prelu(
      (bn): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
    )
    (conv3d): Conv3d(24, 12, kernel_size=(1, 1, 7), stride=(1, 1, 1), padding=(0, 0, 3))
  )
  (spectral_conv2): Spectral_conv(
    (bn_prelu): Bn_prelu(
      (bn): BatchNorm3d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
    )
    (conv3d): Conv3d(36, 12, kernel_size=(1, 1, 7), stride=(1, 1, 1), padding=(0, 0, 3))
  )
  (spectral_conv3): Spectral_conv(
    (bn_prelu): Bn_prelu(
      (bn): BatchNorm3d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
    )
    (conv3d): Conv3d(48, 12, kernel_size=(1, 1, 7), stride=(1, 1, 1), padding=(0, 0, 3))
  )
  (bn_prelu1): Bn

In [7]:
# Train for one epoch and keep the returned loss history in a named variable
# so it can be inspected or plotted later, not just echoed as cell output.
# NOTE(review): the recorded average loss (4.44) is about ln(85), i.e. what a
# uniform guess over 85 classes would score -- confirm the model is learning
# (more epochs, label/class-count check) before relying on it.
train_losses = utils.train(model, dataloader, epochs=1, loss_fn=loss_fn, optimizer=optimizer)
train_losses


Training...
  Batch   194  of    195.    Loss: 4.45     Elapsed: 0:05:22.
  Average training loss: 4.44
  Training epoch took: 0:05:22


[4.438143434280004]

In [None]:
# TODO: callback for ... (unfinished note -- complete this cell or remove it)