Skip to content
This repository has been archived by the owner on Oct 12, 2022. It is now read-only.

Commit

Permalink
#3377 Add database connection doesn't handle AAD
Browse files Browse the repository at this point in the history
  • Loading branch information
MikhailArkhipov committed Apr 4, 2017
1 parent e34d785 commit 92c5967
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 88 deletions.
31 changes: 16 additions & 15 deletions src/Common/Core/Impl/IO/FileSystem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Compression;
using System.Runtime.InteropServices;
Expand All @@ -25,15 +26,15 @@ public sealed class FileSystem : IFileSystem {
}

public string ReadAllText(string path) => File.ReadAllText(path);

public void WriteAllText(string path, string content) => File.WriteAllText(path, content);

public IEnumerable<string> FileReadAllLines(string path) => File.ReadLines(path);

public void FileWriteAllLines(string path, IEnumerable<string> contents) => File.WriteAllLines(path, contents);

public byte[] FileReadAllBytes(string path) => File.ReadAllBytes(path);

public void FileWriteAllBytes(string path, byte[] bytes) => File.WriteAllBytes(path, bytes);

public Stream CreateFile(string path) => File.Create(path);
Expand All @@ -42,30 +43,30 @@ public sealed class FileSystem : IFileSystem {
public bool DirectoryExists(string path) => Directory.Exists(path);

public FileAttributes GetFileAttributes(string path) => File.GetAttributes(path);

public string ToLongPath(string path) {
var sb = new StringBuilder(NativeMethods.MAX_PATH);
NativeMethods.GetLongPathName(path, sb, sb.Capacity);
return sb.ToString();
}

public string ToShortPath(string path) {
var sb = new StringBuilder(NativeMethods.MAX_PATH);
NativeMethods.GetShortPathName(path, sb, sb.Capacity);
return sb.ToString();
}

public IFileVersionInfo GetVersionInfo(string path) {
var fvi = System.Diagnostics.FileVersionInfo.GetVersionInfo(path);
return new FileVersionInfo(fvi.FileMajorPart, fvi.FileMinorPart);
public Version GetFileVersion(string path) {
var fvi = FileVersionInfo.GetVersionInfo(path);
return new Version(fvi.FileMajorPart, fvi.FileMinorPart, fvi.FileBuildPart, fvi.FilePrivatePart);
}

public void DeleteFile(string path) => File.Delete(path);

public void DeleteDirectory(string path, bool recursive) => Directory.Delete(path, recursive);

public string[] GetFileSystemEntries(string path, string searchPattern, SearchOption options) => Directory.GetFileSystemEntries(path, searchPattern, options);

public void CreateDirectory(string path) => Directory.CreateDirectory(path);

public string CompressFile(string path, string relativeTodir) {
Expand All @@ -85,7 +86,7 @@ public sealed class FileSystem : IFileSystem {
string zipFilePath = Path.GetTempFileName();
using (FileStream zipStream = new FileStream(zipFilePath, FileMode.Create))
using (ZipArchive archive = new ZipArchive(zipStream, ZipArchiveMode.Create)) {
foreach(string path in paths) {
foreach (string path in paths) {
if (ct.IsCancellationRequested) {
break;
}
Expand All @@ -111,14 +112,14 @@ public sealed class FileSystem : IFileSystem {

public string CompressDirectory(string path, Matcher matcher, IProgress<string> progress, CancellationToken ct) {
string zipFilePath = Path.GetTempFileName();
using (FileStream zipStream = new FileStream(zipFilePath, FileMode.Create))
using (FileStream zipStream = new FileStream(zipFilePath, FileMode.Create))
using (ZipArchive archive = new ZipArchive(zipStream, ZipArchiveMode.Create)) {
Queue<string> dirs = new Queue<string>();
dirs.Enqueue(path);
while (dirs.Count > 0) {
var dir = dirs.Dequeue();
var subdirs = Directory.GetDirectories(dir);
foreach(var subdir in subdirs) {
foreach (var subdir in subdirs) {
dirs.Enqueue(subdir);
}

Expand Down Expand Up @@ -184,8 +185,8 @@ public enum KnownFolderflags : uint {
int nBufferLength);

[DllImport("Shell32.dll")]
public static extern int SHGetKnownFolderPath([MarshalAs(UnmanagedType.LPStruct)] Guid rfid,
uint dwFlags,
public static extern int SHGetKnownFolderPath([MarshalAs(UnmanagedType.LPStruct)] Guid rfid,
uint dwFlags,
IntPtr hToken,
out IntPtr ppszPath);
}
Expand Down
14 changes: 0 additions & 14 deletions src/Common/Core/Impl/IO/FileVersionInfo.cs

This file was deleted.

2 changes: 1 addition & 1 deletion src/Common/Core/Impl/IO/IFileSystem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public interface IFileSystem {
Stream CreateFile(string path);
Stream FileOpen(string path, FileMode mode);

IFileVersionInfo GetVersionInfo(string path);
Version GetFileVersion(string path);
void DeleteFile(string path);
void DeleteDirectory(string path, bool recursive);
string[] GetFileSystemEntries(string path, string searchPattern, SearchOption options);
Expand Down
9 changes: 0 additions & 9 deletions src/Common/Core/Impl/IO/IFileVersionInfo.cs

This file was deleted.

2 changes: 0 additions & 2 deletions src/Common/Core/Impl/Microsoft.Common.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@
<Compile Include="Disposables\DisposeToken.cs" />
<Compile Include="Extensions\CharExtensions.cs" />
<Compile Include="Extensions\IOExtensions.cs" />
<Compile Include="IO\FileVersionInfo.cs" />
<Compile Include="IO\IFileVersionInfo.cs" />
<Compile Include="OS\IProcessServices.cs" />
<Compile Include="OS\ProcessServices.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
Expand Down
9 changes: 9 additions & 0 deletions src/R/Components/Impl/Resources.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/R/Components/Impl/Resources.resx
Original file line number Diff line number Diff line change
Expand Up @@ -696,4 +696,7 @@ This prompt can be suppressed in R Tools | Options.
<data name="Error_CannotEditExpression" xml:space="preserve">
<value>Expression does not evaluate to function.</value>
</data>
<data name="Error_OdbcDriver" xml:space="preserve">
<value>Connection to SQL database in Azure requires SQL ODBC driver 13.1 or higher. Please install the latest Microsoft ODBC Driver from {0}.</value>
</data>
</root>
24 changes: 19 additions & 5 deletions src/R/Components/Impl/Sql/ConnectionStringConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
namespace Microsoft.R.Components.Sql {
public static class ConnectionStringConverter {
public const string OdbcSqlDriver = "{SQL Server}";
public const string OdbcSql13Driver = "{ODBC Driver 13 for SQL Server}";

public const string OdbcDriverKey = "Driver";
public const string OdbcServerKey = "Server";
Expand All @@ -27,8 +28,9 @@ public static class ConnectionStringConverter {
}
try {
var sql = new SqlConnectionStringBuilder(sqlClientString);
var sqlDriver = sql.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated ? OdbcSql13Driver : OdbcSqlDriver;
var odbc = new OdbcConnectionStringBuilder {
[OdbcDriverKey] = OdbcSqlDriver,
[OdbcDriverKey] = sqlDriver,
[OdbcServerKey] = sql.DataSource,
[OdbcDatabaseKey] = sql.InitialCatalog
};
Expand All @@ -44,8 +46,12 @@ public static class ConnectionStringConverter {
odbc[OdbcTrustServerCertificateKey] = "yes";
}

odbc[OdbcUidKey] = sql.UserID;
odbc[OdbcPasswordKey] = sql.Password;
if (!string.IsNullOrEmpty(sql.UserID)) {
odbc[OdbcUidKey] = sql.UserID;
}
if (!string.IsNullOrEmpty(sql.Password)) {
odbc[OdbcPasswordKey] = sql.Password;
}

return odbc.ConnectionString;
} catch (ArgumentException) { }
Expand All @@ -67,10 +73,18 @@ public static class ConnectionStringConverter {
var sql = new SqlConnectionStringBuilder {
DataSource = server,
InitialCatalog = database,
UserID = odbc.GetValue(OdbcUidKey),
Password = odbc.GetValue(OdbcPasswordKey)
};

var userId = odbc.GetValue(OdbcUidKey);
if(!string.IsNullOrEmpty(userId)) {
sql.UserID = odbc.GetValue(OdbcUidKey);
}

var password = odbc.GetValue(OdbcPasswordKey);
if (!string.IsNullOrEmpty(password)) {
sql.Password = password;
}

// If no password and user name, assume integrated authentication
sql.IntegratedSecurity = string.IsNullOrEmpty(sql.UserID) && string.IsNullOrEmpty(sql.Password);
sql.TrustServerCertificate = string.Compare(odbc.GetValue(OdbcTrustServerCertificateKey), "yes", StringComparison.OrdinalIgnoreCase) == 0;
Expand Down
48 changes: 44 additions & 4 deletions src/R/Components/Impl/Sql/DbConnectionService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@

using System;
using System.ComponentModel.Composition;
using System.Data.SqlClient;
using System.Globalization;
using System.Windows.Forms;
using Microsoft.Common.Core;
using Microsoft.Common.Core.Shell;
using Microsoft.Data.ConnectionUI;
using Microsoft.Win32;

namespace Microsoft.R.Components.Sql {
[Export(typeof(IDbConnectionService))]
internal sealed class DbConnectionService : IDbConnectionService {
private const string _defaultSqlConnectionString = "Data Source=(local);Integrated Security=true";
private const string DefaultSqlConnectionString = "Data Source=(local);Integrated Security=true";

private readonly ICoreShell _coreShell;
private string _odbcConnectionString;

Expand All @@ -22,7 +27,7 @@ internal sealed class DbConnectionService : IDbConnectionService {
public string EditConnectionString(string odbcConnectionString) {
var originalConnectionString = (odbcConnectionString.OdbcToSqlClient()
?? _odbcConnectionString.OdbcToSqlClient())
?? _defaultSqlConnectionString;
?? DefaultSqlConnectionString;
do {
using (var dlg = new DataConnectionDialog()) {
DataSource.AddStandardDataSources(dlg);
Expand All @@ -31,11 +36,15 @@ internal sealed class DbConnectionService : IDbConnectionService {
try {
dlg.ConnectionString = originalConnectionString;
var result = DataConnectionDialog.Show(dlg);
switch(result) {
switch (result) {
case DialogResult.Cancel:
return null;
case DialogResult.OK:
_odbcConnectionString = dlg.ConnectionString.SqlClientToOdbc();
var sqlString = dlg.ConnectionString;
if (IsSqlAADConnection(sqlString)) {
CheckSqlOdbcDriverVersion();
}
_odbcConnectionString = sqlString.SqlClientToOdbc();
break;
}
break;
Expand All @@ -50,5 +59,36 @@ internal sealed class DbConnectionService : IDbConnectionService {

return _odbcConnectionString;
}

private bool IsSqlAADConnection(string connectionString) {
var csb = new SqlConnectionStringBuilder(connectionString);
return csb.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated;
}

internal bool CheckSqlOdbcDriverVersion() {
using (var hklm = _coreShell.Services.Registry.OpenBaseKey(RegistryHive.LocalMachine, RegistryView.Registry64)) {
using (var odbcKey = hklm.OpenSubKey(@"SOFTWARE\ODBC\ODBCINST.INI\ODBC Driver 13 for SQL Server")) {
var driverPath = odbcKey.GetValue("Driver") as string;
if (!string.IsNullOrEmpty(driverPath)) {
var fs = _coreShell.Services.FileSystem;
if (fs.FileExists(driverPath)) {
var version = fs.GetFileVersion(driverPath);
if (version >= new Version("2015.131.4413.46")) {
return true;
}
}
}
}
}
var link = FormatLocalizedLink(_coreShell.AppConstants.LocaleId, "https://www.microsoft.com/{0}/download/details.aspx?id=53339");
_coreShell.ShowErrorMessage(Resources.Error_OdbcDriver.FormatInvariant(link));
_coreShell.Services.ProcessServices.Start(link);
return false;
}

private static string FormatLocalizedLink(uint localeId, string format) {
var culture = CultureInfo.GetCultureInfo((int)localeId);
return string.Format(CultureInfo.InvariantCulture, format, culture.Name);
}
}
}

0 comments on commit 92c5967

Please sign in to comment.