In [None]:
import timm
import torch
import collections
from model import *
from dataset import *
import os
import torch_geometric

# for unimodal ViT

In [None]:
import timm
import torch

# Location of the unimodal ViT checkpoint file.
model_path = 'geolink_vit_large_patch16_224.pth'

# Deserialize on the CPU so that restoring the weights does not need a GPU.
checkpoint = torch.load(model_path, map_location='cpu')
# The checkpoint carries the settings required to rebuild the network.
config = checkpoint['model_config']

# Recreate the architecture from the stored config; pretrained=False because
# the weights come from the checkpoint loaded above.
model = timm.create_model(
    config['architecture'],
    num_classes=config['num_classes'],
    global_pool=config['global_pool'],
    pretrained=False,
)
model.load_state_dict(checkpoint['model_state_dict'])

# The position embedding is sinusoidal and excluded from back-propagation;
# the stored flag decides whether it receives gradients.
model.pos_embed.requires_grad = config['pos_embed_requires_grad']

# Display the restored model as the cell output.
model

# for multimodal GeoLink

In [None]:
import timm
import torch
from model import *
from dataset import *

'''
    You can use the following code to obtain the fusion embeddings.
    The default output embeddings are from the 7th/11th/15th layers of ViT-L and the hybrid RS-OSM embeddings.
    The hybrid RS-OSM embeddings are generated by integrating features from the last (23rd) layer of ViT-L
    with OSM embeddings (the detailed structure is shown on the right of Fig. 8b in the paper).
    Such embeddings can be applied to downstream tasks through a task-specific decoder, such as UperNet for semantic segmentation.
'''

# This cell re-declares its own imports so it can be run on its own, but it
# also needs `os` and `torch_geometric`; import them here as well.
import os

import torch_geometric

# Path to the multimodal GeoLink checkpoint file.
ckpt_fp = r'geolink_mutimodal_vit_large_patch16_224.pth'
# Deserialize on the CPU so that restoring the weights does not need a GPU.
checkpoint = torch.load(ckpt_fp, map_location='cpu')
config = checkpoint['model_config']

# Image branch: architecture rebuilt from the config stored in the checkpoint.
img_encoder = timm.create_model(
    config['architecture'],
    pretrained=False,
    num_classes=config['num_classes'],
    global_pool=config['global_pool']
)

# OSM graph branch, wrapped together with the image branch into GeoLink.
osm_encoder = OSMHeteroGAT()
geolink = GeoLink(img_encoder, osm_encoder)
# Print the result of loading the weights so any key mismatch is visible.
msg = geolink.load_state_dict(checkpoint['model_state_dict'])
print(msg)

# 23 means the fusion embedding with shape of [256,14,14]
multi_encoder = GeoLink_Fusion_Embedding(geolink, output_layers=[7, 11, 15, 23])

data_root = r'./example_data/ufz_example'
# Every entry of the `graph` sub-directory is used as a sample name.
samples = os.listdir(os.path.join(data_root, 'graph'))

trainset = DownstreamDataset(data_root, file_names=samples)
# `torch_geometric.data.DataLoader` is a deprecated alias of
# `torch_geometric.loader.DataLoader`; use the supported location directly.
trainloader = torch_geometric.loader.DataLoader(trainset, batch_size=2,
                             pin_memory=True, num_workers=1, drop_last=True, shuffle=True)
train_iter = iter(trainloader)
# Fetch a single batch.
img, graph = next(train_iter)
result = multi_encoder(img, graph)  # list size [4, batch_size, embed_dim, 14, 14]

# for multimodal segmentation

In [None]:
import timm
import torch
from model import *
from dataset import *

# This cell re-declares its own imports so it can be run on its own, but it
# also needs `os` and `torch_geometric`; import them here as well.
import os

import torch_geometric

# Path to the multimodal GeoLink checkpoint file.
ckpt_fp = r'geolink_mutimodal_vit_large_patch16_224.pth'
# Deserialize on the CPU so that restoring the weights does not need a GPU.
checkpoint = torch.load(ckpt_fp, map_location='cpu')
config = checkpoint['model_config']

# Image branch: architecture rebuilt from the config stored in the checkpoint.
img_encoder = timm.create_model(
    config['architecture'],
    pretrained=False,
    num_classes=config['num_classes'],
    global_pool=config['global_pool']
)

# OSM graph branch, wrapped together with the image branch into GeoLink.
osm_encoder = OSMHeteroGAT()
geolink = GeoLink(img_encoder, osm_encoder)
# Print the result of loading the weights so any key mismatch is visible.
msg = geolink.load_state_dict(checkpoint['model_state_dict'])
print(msg)

# The fusion encoder feeds a UPerNet-style segmentation head with 9 classes.
multi_encoder = GeoLink_Fusion_Embedding(geolink, output_layers=[7, 11, 15, 23])
model = SegUPerNet(encoder=multi_encoder, num_classes=9, channels=512)

data_root = r'./example_data/ufz_example'
# Every entry of the `graph` sub-directory is used as a sample name.
samples = os.listdir(os.path.join(data_root, 'graph'))

trainset = DownstreamDataset_UFZ(data_root, file_names=samples)
# `torch_geometric.data.DataLoader` is a deprecated alias of
# `torch_geometric.loader.DataLoader`; use the supported location directly.
trainloader = torch_geometric.loader.DataLoader(trainset, batch_size=2,
                             pin_memory=True, num_workers=1, drop_last=True, shuffle=True)
train_iter = iter(trainloader)
# Fetch a single batch; `label` is unpacked but not used by this forward pass.
img, graph, label = next(train_iter)
result = model(img, graph)  # list [batch_size, cls_num, 224, 224]

# prepare OSM graph

In [None]:
# `gpd` was previously only imported by a later cell, so this cell raised a
# NameError on a fresh kernel; import what it needs here.
import os

import geopandas as gpd

example_path = r'./example_data/osm2graph_example/'  # path to the osm example dir
name = 'w221548378_US_21'
# Preview the polygon layer of one example sample.
polygon_file = gpd.read_file(os.path.join(example_path, name, name + '_polygon.geojson'))
polygon_file

In [None]:
import geopandas as gpd
import pickle
import os
from prepare_data import *

import torch  # only needed here to pick a device that actually exists


def _read_optional_layer(path):
    """Read one OSM vector layer, returning None when it cannot be loaded.

    Missing layers are expected (see the note before the reads below), so a
    failure is reported and turned into None instead of aborting the cell.
    Only `Exception` is caught, so KeyboardInterrupt/SystemExit still propagate.
    """
    try:
        return gpd.read_file(path)
    except Exception as exc:
        print(f'Could not read {path}: {exc}')
        return None


example_path = r'./example_data/osm2graph_example/'  # path to the osm example dir
tagw_path = r'all_tags30_frequency1.json'

# Keep the original preference for the second GPU, but fall back when the
# machine does not have one rather than failing on an invalid device index.
if torch.cuda.device_count() > 1:
    osm_device = 'cuda:1'
elif torch.cuda.is_available():
    osm_device = 'cuda:0'
else:
    # NOTE(review): assumes OSM2Graph accepts a CPU device string - confirm.
    osm_device = 'cpu'

osm_process = OSM2Graph(tagw_path, osm_device)
name = 'a472140250_FI_21'  # example name

# Obtain the geographic bounding box of the given RS image.
# Here, the bounding box is saved in the meta file.
# You can also calculate it from the original RS data like TIFF; tools are prepared in prepare_data/utils.py.
# Note that the geographic coordinate system should be the same between RS and OSM data.

# SECURITY: pickle.load can execute arbitrary code; only open meta files that
# come from a trusted source.
with open(os.path.join(example_path, name, name + '.pickle'), 'rb') as file:
    meta_data = pickle.load(file)
bbox = meta_data['bbox']
# bbox is indexed as [west, south, east, north].
north, south, east, west = bbox[3], bbox[1], bbox[2], bbox[0]
width = east - west
height = north - south

# OSM2Graph can handle situations where one or two vector types are missing.
# If none of the three vector types are available, then just use the RS image.
polygon_file = _read_optional_layer(os.path.join(example_path, name, name + '_polygon.geojson'))
line_file = _read_optional_layer(os.path.join(example_path, name, name + '_line.geojson'))
point_file = _read_optional_layer(os.path.join(example_path, name, name + '_point.geojson'))

data = osm_process.process(polygon_file, line_file, point_file, north, south, east, west)